diff --git a/benches/main.rs b/benches/main.rs index c12c003..f2ecc33 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -1,12 +1,13 @@ use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion}; -use datafusion::common::ScalarValue; +use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::ColumnarValue; -use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf}; +use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs}; +use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_top_level_sorted_udf, json_get_str_udf}; fn bench_json_contains(b: &mut Bencher) { let json_contains = json_contains_udf(); - let args = &[ + let args = vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some( r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), ))), @@ -14,7 +15,15 @@ fn bench_json_contains(b: &mut Bencher) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - b.iter(|| json_contains.invoke_batch(args, 1).unwrap()); + b.iter(|| { + json_contains + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: 1, + return_type: &DataType::Boolean, + }) + .unwrap() + }); } fn bench_json_get_str(b: &mut Bencher) { @@ -27,12 +36,89 @@ fn bench_json_get_str(b: &mut Bencher) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); + b.iter(|| { + json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + }); +} + +fn make_json_negative_testcase() -> String { + // build a json with keys "b0", "b1" ... 
"b99", each with a large value ("a" repeated 1024 times) + let kvs = (0..100) + .map(|i| format!(r#""b{}": "{}""#, i, "a".repeat(1024))) + .collect::<Vec<_>>() + .join(","); + format!(r"{{ {kvs} }}") +} + +fn bench_json_get_str_negative(b: &mut Bencher) { + let json_get_str = json_get_str_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(make_json_negative_testcase()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), // lexicographically less than "b0" + ]; + + b.iter(|| { + json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + }); +} + +fn bench_json_get_str_sorted(b: &mut Bencher) { + let json_get_str = json_get_str_top_level_sorted_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), + ]; + + b.iter(|| { + json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + }); +} + +fn bench_json_get_str_sorted_negative(b: &mut Bencher) { + let json_get_str = json_get_str_top_level_sorted_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(make_json_negative_testcase()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), // lexicographically less than "b0" + ]; + + b.iter(|| { + json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + }); } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("json_contains", bench_json_contains); c.bench_function("json_get_str", bench_json_get_str); + c.bench_function("json_get_str_negative", bench_json_get_str_negative); + c.bench_function("json_get_str_sorted", 
bench_json_get_str_sorted); + c.bench_function("json_get_str_sorted_negative", bench_json_get_str_sorted_negative); } criterion_group!(benches, criterion_benchmark); diff --git a/src/common.rs b/src/common.rs index 66505cb..5bfd079 100644 --- a/src/common.rs +++ b/src/common.rs @@ -511,7 +511,11 @@ fn wrap_as_large_dictionary(original: &dyn AnyDictionaryArray, new_values: Array DictionaryArray::new(keys, new_values) } -pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { +pub fn jiter_json_find<'j>( + opt_json: Option<&'j str>, + path: &[JsonPath], + mut sorted: Sortedness, +) -> Option<(Jiter<'j>, Peek)> { let json_str = opt_json?; let mut jiter = Jiter::new(json_str.as_bytes()); let mut peek = jiter.peek().ok()?; @@ -521,6 +525,11 @@ pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Opti let mut next_key = jiter.known_object().ok()??; while next_key != *key { + if next_key > *key && matches!(sorted, Sortedness::Recursive | Sortedness::TopLevel) { + // The current object is sorted and next_key is lexicographically greater than the key + // we are looking for, so we can stop early here. + return None; + } jiter.next_skip().ok()?; next_key = jiter.next_key().ok()??; } @@ -541,6 +550,9 @@ pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Opti return None; } } + if sorted == Sortedness::TopLevel { + sorted = Sortedness::Unspecified; + } } Some((jiter, peek)) } @@ -585,3 +597,35 @@ fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> Pr } PrimitiveArray::new(keys.values().clone(), Some(null_mask.into())) } + +/// Information about the sortedness of a JSON object. +/// This is used to optimize key lookups by stopping early when the current key is +/// lexicographically greater than the key we are looking for and the object is known to be sorted. 
+#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) enum Sortedness { + /// No guarantees about the order of the elements. + Unspecified, + /// Only the outermost object is known to be sorted. + /// If the outermost item is not an object, this is equivalent to `Unspecified`. + TopLevel, + /// All objects are known to be sorted, including objects nested within arrays. + Recursive, +} + +impl Sortedness { + pub(crate) fn iter() -> impl Iterator<Item = Sortedness> { + [Sortedness::Unspecified, Sortedness::TopLevel, Sortedness::Recursive] + .iter() + .copied() + } +} + +impl Sortedness { + pub(crate) fn function_name_suffix(self) -> &'static str { + match self { + Sortedness::Unspecified => "", + Sortedness::TopLevel => "_top_level_sorted", + Sortedness::Recursive => "_recursive_sorted", + } + } +} diff --git a/src/common_macros.rs b/src/common_macros.rs index a2c6cd0..8435ffa 100644 --- a/src/common_macros.rs +++ b/src/common_macros.rs @@ -15,7 +15,7 @@ /// /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl macro_rules! make_udf_function { - ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr) => { + ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr, $sorted:expr) => { paste::paste! { #[doc = $doc] #[must_use] pub fn $expr_fn_name($($arg: datafusion::logical_expr::Expr),*) -> datafusion::logical_expr::Expr { @@ -37,7 +37,7 @@ macro_rules! 
make_udf_function { [< STATIC_ $expr_fn_name:upper >] .get_or_init(|| { std::sync::Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl( - <$udf_impl>::default(), + <$udf_impl>::new($sorted), )) }) .clone() diff --git a/src/json_as_text.rs b/src/json_as_text.rs index ce73375..11d7037 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -4,30 +4,51 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonAsText, json_as_text, json_data path, - r#"Get any value from a JSON string by its "path", represented as a string"# + r#"Get any value from a JSON string by its "path", represented as a string"#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonAsText, + json_as_text_top_level_sorted, + json_data path, + r#"Get any value from a JSON string by its "path", represented as a string; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonAsText, + json_as_text_recursive_sorted, + json_data path, + r#"Get any value from a JSON string by its "path", represented as a string; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonAsText { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonAsText { - fn default() -> Self { +impl JsonAsText { 
+ pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_as_text".to_string()], + aliases: [format!("json_as_text{}", sorted.function_name_suffix())], + sorted, } } } @@ -49,8 +70,8 @@ impl ScalarUDFImpl for JsonAsText { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_as_text) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |args, path| jiter_json_as_text(args, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -82,8 +103,8 @@ impl InvokeResult for StringArray { } } -fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { match peek { Peek::Null => { jiter.known_null()?; diff --git a/src/json_contains.rs b/src/json_contains.rs index 0eda035..13de465 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -5,7 +5,7 @@ use datafusion::arrow::array::BooleanBuilder; use datafusion::arrow::datatypes::DataType; use datafusion::common::arrow::array::{ArrayRef, BooleanArray}; use datafusion::common::{plan_err, Result, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use crate::common::{invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; use crate::common_macros::make_udf_function; @@ -14,20 +14,39 @@ make_udf_function!( JsonContains, json_contains, json_data path, - r#"Does the key/index exist within the JSON value as the specified "path"?"# + r#"Does the key/index exist within the JSON value as 
the specified "path"?"#, + crate::common::Sortedness::Unspecified +); + +make_udf_function!( + JsonContains, + json_contains_top_level_sorted, + json_data path, + r#"Does the key/index exist within the JSON value as the specified "path"; assumes the JSON string's top level object's keys are sorted?"#, + crate::common::Sortedness::TopLevel +); + +make_udf_function!( + JsonContains, + json_contains_recursive_sorted, + json_data path, + r#"Does the key/index exist within the JSON value as the specified "path"; assumes all json object's keys are sorted?"#, + crate::common::Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonContains { signature: Signature, aliases: [String; 1], + sorted: crate::common::Sortedness, } -impl Default for JsonContains { - fn default() -> Self { +impl JsonContains { + pub fn new(sorted: crate::common::Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_contains".to_string()], + aliases: [format!("json_contains{}", sorted.function_name_suffix())], + sorted, } } } @@ -53,8 +72,8 @@ impl ScalarUDFImpl for JsonContains { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - invoke::(args, jiter_json_contains) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + invoke::(&args.args, |json, path| jiter_json_contains(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -88,6 +107,10 @@ impl InvokeResult for BooleanArray { } #[allow(clippy::unnecessary_wraps)] -fn jiter_json_contains(json_data: Option<&str>, path: &[JsonPath]) -> Result { - Ok(jiter_json_find(json_data, path).is_some()) +fn jiter_json_contains( + json_data: Option<&str>, + path: &[JsonPath], + sorted: crate::common::Sortedness, +) -> Result { + Ok(jiter_json_find(json_data, path, sorted).is_some()) } diff --git a/src/json_get.rs b/src/json_get.rs index c0ba131..086fcf1 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -5,11 +5,12 @@ use datafusion::arrow::array::ArrayRef; use 
datafusion::arrow::array::UnionArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion::scalar::ScalarValue; use jiter::{Jiter, NumberAny, NumberInt, Peek}; use crate::common::InvokeResult; +use crate::common::Sortedness; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; use crate::common_macros::make_udf_function; use crate::common_union::{JsonUnion, JsonUnionField}; @@ -18,22 +19,39 @@ make_udf_function!( JsonGet, json_get, json_data path, - r#"Get a value from a JSON string by its "path""# + r#"Get a value from a JSON string by its "path""#, + Sortedness::Unspecified ); -// build_typed_get!(JsonGet, "json_get", Union, Float64Array, jiter_json_get_float); +make_udf_function!( + JsonGet, + json_get_top_level_sorted, + json_data path, + r#"Get a value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGet, + json_get_recursive_sorted, + json_data path, + r#"Get a value from a JSON string by its "path"; assumes all object's keys are sorted."#, + Sortedness::Recursive +); #[derive(Debug)] pub(super) struct JsonGet { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGet { - fn default() -> Self { +impl JsonGet { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get".to_string()], + aliases: [format!("json_get{}", sorted.function_name_suffix())], + sorted, } } } @@ -55,8 +73,8 @@ impl ScalarUDFImpl for JsonGet { return_type_check(arg_types, self.name(), JsonUnion::data_type()) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, 
jiter_json_get_union) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json, path| jiter_json_get_union(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -93,8 +111,12 @@ impl InvokeResult for JsonUnion { } } -fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_get_union( + opt_json: Option<&str>, + path: &[JsonPath], + sorted: Sortedness, +) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { build_union(&mut jiter, peek) } else { get_err!() diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 2bae88b..75975d9 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -3,30 +3,49 @@ use std::any::Any; use datafusion::arrow::array::BooleanArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath, Sortedness}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetBool, json_get_bool, json_data path, - r#"Get an boolean value from a JSON string by its "path""# + r#"Get an boolean value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetBool, + json_get_bool_top_level_sorted, + json_data path, + r#"Get an boolean value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetBool, + json_get_bool_recursive_sorted, + json_data path, + r#"Get an 
boolean value from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetBool { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGetBool { - fn default() -> Self { +impl JsonGetBool { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_bool".to_string()], + aliases: [format!("json_get_bool{}", sorted.function_name_suffix())], + sorted, } } } @@ -48,8 +67,8 @@ impl ScalarUDFImpl for JsonGetBool { return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_bool) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json, path| jiter_json_get_bool(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -57,8 +76,8 @@ impl ScalarUDFImpl for JsonGetBool { } } -fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { Peek::True | Peek::False => Ok(jiter.known_bool(peek)?), _ => get_err!(), diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 24b00a5..c734df2 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -4,30 +4,51 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Float64Array, Float64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, 
ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberAny, Peek}; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetFloat, json_get_float, json_data path, - r#"Get a float value from a JSON string by its "path""# + r#"Get a float value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetFloat, + json_get_float_top_level_sorted, + json_data path, + r#"Get an float value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetFloat, + json_get_float_recursive_sorted, + json_data path, + r#"Get an float value from a JSON string by its "path"; assumes all object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetFloat { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGetFloat { - fn default() -> Self { +impl JsonGetFloat { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_float".to_string()], + aliases: [format!("json_get_float{}", sorted.function_name_suffix())], + sorted, } } } @@ -49,8 +70,10 @@ impl ScalarUDFImpl for JsonGetFloat { return_type_check(arg_types, self.name(), DataType::Float64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_float) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json_data, path| { + jiter_json_get_float(json_data, path, self.sorted) + }) } fn aliases(&self) -> &[String] { @@ -83,8 +106,8 @@ impl InvokeResult for Float64Array { } } -fn jiter_json_get_float(json_data: 
Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { // numbers are represented by everything else in peek, hence doing it this way Peek::Null diff --git a/src/json_get_int.rs b/src/json_get_int.rs index a6a2b90..fefa0b6 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -4,30 +4,51 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Int64Array, Int64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberInt, Peek}; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetInt, json_get_int, json_data path, - r#"Get an integer value from a JSON string by its "path""# + r#"Get an integer value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetInt, + json_get_int_top_level_sorted, + json_data path, + r#"Get an integer value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetInt, + json_get_int_recursive_sorted, + json_data path, + r#"Get an integer value from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetInt { signature: Signature, 
aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGetInt { - fn default() -> Self { +impl JsonGetInt { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_int".to_string()], + aliases: [format!("json_get_int{}", sorted.function_name_suffix())], + sorted, } } } @@ -49,8 +70,8 @@ impl ScalarUDFImpl for JsonGetInt { return_type_check(arg_types, self.name(), DataType::Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_int) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json, path| jiter_json_get_int(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -83,8 +104,8 @@ impl InvokeResult for Int64Array { } } -fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { // numbers are represented by everything else in peek, hence doing it this way Peek::Null diff --git a/src/json_get_json.rs b/src/json_get_json.rs index c5890b4..e9fe23a 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -3,29 +3,48 @@ use std::any::Any; use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath, Sortedness}; use 
crate::common_macros::make_udf_function; make_udf_function!( JsonGetJson, json_get_json, json_data path, - r#"Get a nested raw JSON string from a JSON string by its "path""# + r#"Get a nested raw JSON string from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetJson, + json_get_json_top_level_sorted, + json_data path, + r#"Get a nested raw JSON string from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetJson, + json_get_json_recursive_sorted, + json_data path, + r#"Get a nested raw JSON string from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetJson { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGetJson { - fn default() -> Self { +impl JsonGetJson { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_json".to_string()], + aliases: [format!("json_get_json{}", sorted.function_name_suffix())], + sorted, } } } @@ -47,8 +66,8 @@ impl ScalarUDFImpl for JsonGetJson { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_json) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json, path| jiter_json_get_json(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -56,8 +75,8 @@ impl ScalarUDFImpl for JsonGetJson { } } -fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { let start = 
jiter.current_index(); jiter.known_skip(peek)?; let object_slice = jiter.slice_to_current(start); diff --git a/src/json_get_str.rs b/src/json_get_str.rs index 633321f..5538c46 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -3,30 +3,49 @@ use std::any::Any; use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath, Sortedness}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetStr, json_get_str, json_data path, - r#"Get a string value from a JSON string by its "path""# + r#"Get a string value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetStr, + json_get_str_top_level_sorted, + json_data path, + r#"Get a string value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetStr, + json_get_str_recursive_sorted, + json_data path, + r#"Get a string value from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetStr { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGetStr { - fn default() -> Self { +impl JsonGetStr { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_str".to_string()], + aliases: [format!("json_get_str{}", sorted.function_name_suffix())], + sorted, } } } @@ -48,8 +67,8 @@ impl ScalarUDFImpl 
for JsonGetStr { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_str) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json, path| jiter_json_get_str(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -57,8 +76,8 @@ impl ScalarUDFImpl for JsonGetStr { } } -fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { Peek::String => Ok(jiter.known_str()?.to_owned()), _ => get_err!(), diff --git a/src/json_length.rs b/src/json_length.rs index 8d0bfd3..7da6ab8 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -4,30 +4,54 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, UInt64Array, UInt64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonLength, json_length, json_data path, - r"Get the length of the array or object at the given path." 
+ r"Get the length of the array or object at the given path.", + Sortedness::Unspecified +); + +make_udf_function!( + JsonLength, + json_length_top_level_sorted, + json_data path, + r"Get the length of the array or object at the given path; assumes the JSON object's keys are sorted.", + Sortedness::TopLevel +); + +make_udf_function!( + JsonLength, + json_length_recursive_sorted, + json_data path, + r"Get the length of the array or object at the given path; assumes all object's keys are sorted.", + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonLength { signature: Signature, aliases: [String; 2], + sorted: Sortedness, } -impl Default for JsonLength { - fn default() -> Self { +impl JsonLength { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_length".to_string(), "json_len".to_string()], + aliases: [ + format!("json_length{}", sorted.function_name_suffix()), + format!("json_len{}", sorted.function_name_suffix()), + ], + sorted, } } } @@ -49,8 +73,8 @@ impl ScalarUDFImpl for JsonLength { return_type_check(arg_types, self.name(), DataType::UInt64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_length) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |json, path| jiter_json_length(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -83,8 +107,8 @@ impl InvokeResult for UInt64Array { } } -fn jiter_json_length(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_length(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { match peek { Peek::Array => { let mut peek_opt = jiter.known_array()?; diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs index 04cac0e..349e75d 100644 --- 
a/src/json_object_keys.rs +++ b/src/json_object_keys.rs @@ -4,30 +4,54 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder}; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonObjectKeys, json_object_keys, json_data path, - r"Get the keys of a JSON object as an array." + r"Get the keys of a JSON object as an array.", + Sortedness::Unspecified +); + +make_udf_function!( + JsonObjectKeys, + json_keys_sorted, + json_data path, + r"Get the keys of a JSON object as an array; assumes the JSON object's keys are sorted.", + Sortedness::TopLevel +); + +make_udf_function!( + JsonObjectKeys, + json_keys_recursive_sorted, + json_data path, + r"Get the keys of a JSON object as an array; assumes all object's keys are sorted.", + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonObjectKeys { signature: Signature, aliases: [String; 2], + sorted: Sortedness, } -impl Default for JsonObjectKeys { - fn default() -> Self { +impl JsonObjectKeys { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_object_keys".to_string(), "json_keys".to_string()], + aliases: [ + format!("json_object_keys{}", sorted.function_name_suffix()), + format!("json_keys{}", sorted.function_name_suffix()), + ], + sorted, } } } @@ -53,8 +77,10 @@ impl ScalarUDFImpl for JsonObjectKeys { ) } - fn invoke(&self, args: 
&[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_object_keys) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, |opt_json, path| { + jiter_json_object_keys(opt_json, path, self.sorted) + }) } fn aliases(&self) -> &[String] { @@ -106,8 +132,12 @@ fn keys_to_scalar(opt_keys: Option>) -> ScalarValue { ScalarValue::List(Arc::new(array)) } -fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) -> Result, GetError> { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_object_keys( + opt_json: Option<&str>, + path: &[JsonPath], + sorted: Sortedness, +) -> Result, GetError> { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { match peek { Peek::Object => { let mut opt_key = jiter.known_object()?; diff --git a/src/lib.rs b/src/lib.rs index cb0f25a..10fabbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,29 +23,43 @@ mod rewrite; pub use common_union::{JsonUnionEncoder, JsonUnionValue}; pub mod functions { - pub use crate::json_as_text::json_as_text; - pub use crate::json_contains::json_contains; - pub use crate::json_get::json_get; - pub use crate::json_get_bool::json_get_bool; - pub use crate::json_get_float::json_get_float; - pub use crate::json_get_int::json_get_int; - pub use crate::json_get_json::json_get_json; - pub use crate::json_get_str::json_get_str; - pub use crate::json_length::json_length; - pub use crate::json_object_keys::json_object_keys; + pub use crate::json_as_text::{json_as_text, json_as_text_recursive_sorted, json_as_text_top_level_sorted}; + pub use crate::json_contains::{json_contains, json_contains_recursive_sorted, json_contains_top_level_sorted}; + pub use crate::json_get::{json_get, json_get_recursive_sorted, json_get_top_level_sorted}; + pub use crate::json_get_bool::{json_get_bool, json_get_bool_recursive_sorted, json_get_bool_top_level_sorted}; + pub use crate::json_get_float::{json_get_float, 
json_get_float_recursive_sorted, json_get_float_top_level_sorted}; + pub use crate::json_get_int::{json_get_int, json_get_int_recursive_sorted, json_get_int_top_level_sorted}; + pub use crate::json_get_json::{json_get_json, json_get_json_recursive_sorted, json_get_json_top_level_sorted}; + pub use crate::json_get_str::{json_get_str, json_get_str_recursive_sorted, json_get_str_top_level_sorted}; + pub use crate::json_length::{json_length, json_length_recursive_sorted, json_length_top_level_sorted}; + pub use crate::json_object_keys::{json_keys_recursive_sorted, json_keys_sorted, json_object_keys}; } pub mod udfs { - pub use crate::json_as_text::json_as_text_udf; - pub use crate::json_contains::json_contains_udf; - pub use crate::json_get::json_get_udf; - pub use crate::json_get_bool::json_get_bool_udf; - pub use crate::json_get_float::json_get_float_udf; - pub use crate::json_get_int::json_get_int_udf; - pub use crate::json_get_json::json_get_json_udf; - pub use crate::json_get_str::json_get_str_udf; - pub use crate::json_length::json_length_udf; - pub use crate::json_object_keys::json_object_keys_udf; + pub use crate::json_as_text::{ + json_as_text_recursive_sorted_udf, json_as_text_top_level_sorted_udf, json_as_text_udf, + }; + pub use crate::json_contains::{ + json_contains_recursive_sorted_udf, json_contains_top_level_sorted_udf, json_contains_udf, + }; + pub use crate::json_get::{json_get_recursive_sorted_udf, json_get_top_level_sorted_udf, json_get_udf}; + pub use crate::json_get_bool::{ + json_get_bool_recursive_sorted_udf, json_get_bool_top_level_sorted_udf, json_get_bool_udf, + }; + pub use crate::json_get_float::{ + json_get_float_recursive_sorted_udf, json_get_float_top_level_sorted_udf, json_get_float_udf, + }; + pub use crate::json_get_int::{ + json_get_int_recursive_sorted_udf, json_get_int_top_level_sorted_udf, json_get_int_udf, + }; + pub use crate::json_get_json::{ + json_get_json_recursive_sorted_udf, json_get_json_top_level_sorted_udf, 
json_get_json_udf, + }; + pub use crate::json_get_str::{ + json_get_str_recursive_sorted_udf, json_get_str_top_level_sorted_udf, json_get_str_udf, + }; + pub use crate::json_length::{json_length_recursive_sorted_udf, json_length_top_level_sorted_udf, json_length_udf}; + pub use crate::json_object_keys::{json_keys_recursive_sorted_udf, json_keys_sorted_udf, json_object_keys_udf}; } /// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`]. @@ -60,15 +74,35 @@ pub mod udfs { pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ json_get::json_get_udf(), + json_get::json_get_top_level_sorted_udf(), + json_get::json_get_recursive_sorted_udf(), json_get_bool::json_get_bool_udf(), + json_get_bool::json_get_bool_top_level_sorted_udf(), + json_get_bool::json_get_bool_recursive_sorted_udf(), json_get_float::json_get_float_udf(), + json_get_float::json_get_float_top_level_sorted_udf(), + json_get_float::json_get_float_recursive_sorted_udf(), json_get_int::json_get_int_udf(), + json_get_int::json_get_int_top_level_sorted_udf(), + json_get_int::json_get_int_recursive_sorted_udf(), json_get_json::json_get_json_udf(), + json_get_json::json_get_json_top_level_sorted_udf(), + json_get_json::json_get_json_recursive_sorted_udf(), json_as_text::json_as_text_udf(), + json_as_text::json_as_text_top_level_sorted_udf(), + json_as_text::json_as_text_recursive_sorted_udf(), json_get_str::json_get_str_udf(), + json_get_str::json_get_str_top_level_sorted_udf(), + json_get_str::json_get_str_recursive_sorted_udf(), json_contains::json_contains_udf(), + json_contains::json_contains_top_level_sorted_udf(), + json_contains::json_contains_recursive_sorted_udf(), json_length::json_length_udf(), + json_length::json_length_top_level_sorted_udf(), + json_length::json_length_recursive_sorted_udf(), json_object_keys::json_object_keys_udf(), + json_object_keys::json_keys_sorted_udf(), + 
json_object_keys::json_keys_recursive_sorted_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/src/rewrite.rs b/src/rewrite.rs index 4161154..ce05d51 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::sync::LazyLock; use datafusion::arrow::datatypes::DataType; use datafusion::common::config::ConfigOptions; @@ -11,8 +12,12 @@ use datafusion::logical_expr::expr_rewriter::FunctionRewrite; use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; use datafusion::logical_expr::sqlparser::ast::BinaryOperator; use datafusion::logical_expr::ScalarUDF; +use datafusion::logical_expr::ScalarUDFImpl; use datafusion::scalar::ScalarValue; +use crate::common::Sortedness; +use crate::json_get::JsonGet; + #[derive(Debug)] pub(crate) struct JsonFunctionRewriter; @@ -31,11 +36,14 @@ impl FunctionRewrite for JsonFunctionRewriter { } } +static JSON_GET_FUNC_NAMES: LazyLock> = + LazyLock::new(|| Sortedness::iter().map(|s| JsonGet::new(s).name().to_string()).collect()); + /// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of /// extracting the right value type from JSON without the need to materialize the JSON union. 
fn optimise_json_get_cast(cast: &Cast) -> Option> { let scalar_func = extract_scalar_function(&cast.expr)?; - if scalar_func.func.name() != "json_get" { + if !JSON_GET_FUNC_NAMES.contains(&scalar_func.func.name().to_owned()) { return None; } let func = match &cast.data_type { @@ -53,18 +61,26 @@ fn optimise_json_get_cast(cast: &Cast) -> Option> { }))) } +/// Names of the JSON UDFs, for every `Sortedness` variant, whose nested calls `unnest_json_calls` is allowed to flatten. +static JSON_FUNCTION_NAMES: LazyLock> = LazyLock::new(|| { + Sortedness::iter() + .flat_map(|s| { + [ + crate::json_get::JsonGet::new(s).name().to_string(), + crate::json_get_bool::JsonGetBool::new(s).name().to_string(), + crate::json_get_float::JsonGetFloat::new(s).name().to_string(), + crate::json_get_int::JsonGetInt::new(s).name().to_string(), + crate::json_get_json::JsonGetJson::new(s).name().to_string(), + crate::json_get_str::JsonGetStr::new(s).name().to_string(), + crate::json_as_text::JsonAsText::new(s).name().to_string(), + ] + }) + .collect() +}); + // Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')` fn unnest_json_calls(func: &ScalarFunction) -> Option> { - if !matches!( - func.func.name(), - "json_get" - | "json_get_bool" - | "json_get_float" - | "json_get_int" - | "json_get_json" - | "json_get_str" - | "json_as_text" - ) { + if !JSON_FUNCTION_NAMES.contains(&func.func.name().to_owned()) { return None; } let mut outer_args_iter = func.args.iter(); diff --git a/tests/main.rs b/tests/main.rs index 0e65338..215fc70 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema}; use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType}; use datafusion::assert_batches_eq; use datafusion::common::ScalarValue; -use datafusion::logical_expr::ColumnarValue; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion::prelude::SessionContext; use datafusion_functions_json::udfs::json_get_str_udf; use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large,
run_query_params}; @@ -72,7 +72,7 @@ async fn test_json_get_union() { "+------------------+--------------------------------------+", "| object_foo | {str=abc} |", "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", + "| object_foo_obj | {object={\"bar\": 1}} |", "| object_foo_null | {null=} |", "| object_bar | {null=} |", "| list_foo | {null=} |", @@ -147,6 +147,126 @@ async fn test_json_get_str() { assert_batches_eq!(expected, &batches); } +#[tokio::test] +async fn test_json_get_str_top_level_sorted() { + let batches = run_query("select name, json_get_str_top_level_sorted(json_data, 'aaa') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_top_level_sorted(test.json_data,Utf8(\"aaa\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_top_level_sorted(json_data, 'foo') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_top_level_sorted(test.json_data,Utf8(\"foo\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | abc |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, 
json_get_str_top_level_sorted(json_data, 'zzz') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_top_level_sorted(test.json_data,Utf8(\"zzz\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_get_str_recursive_level_sorted() { + let batches = run_query("select name, json_get_str_recursive_sorted(json_data, 'aaa') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_recursive_sorted(test.json_data,Utf8(\"aaa\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_recursive_sorted(json_data, 'foo') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_recursive_sorted(test.json_data,Utf8(\"foo\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | abc |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + 
"+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_recursive_sorted(json_data, 'zzz') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_recursive_sorted(test.json_data,Utf8(\"zzz\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_json_get_str_equals() { let sql = "select name, json_get_str(json_data, 'foo')='abc' from test"; @@ -226,6 +346,10 @@ async fn test_json_get_cast_int() { let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); + let sql = r#"select json_get_top_level_sorted('{"foo": 42}', 'foo')::int"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); + // floats not allowed let sql = r#"select json_get('{"foo": 4.2}', 'foo')::int"#; let batches = run_query(sql).await.unwrap(); @@ -280,6 +404,10 @@ async fn test_json_get_cast_float() { let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::float"#; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); + + let sql = r#"select json_get_top_level_sorted('{"foo": 4.2e2}', 'foo')::float"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); } #[tokio::test] @@ -287,6 +415,10 @@ async fn 
test_json_get_cast_numeric() { let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::numeric"#; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); + + let sql = r#"select json_get_top_level_sorted('{"foo": 4.2e2}', 'foo')::numeric"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); } #[tokio::test] @@ -324,7 +456,7 @@ async fn test_json_get_json() { "+------------------+-------------------------------------------+", "| object_foo | \"abc\" |", "| object_foo_array | [1] |", - "| object_foo_obj | {} |", + "| object_foo_obj | {\"bar\": 1} |", "| object_foo_null | null |", "| object_bar | |", "| list_foo | |", @@ -500,7 +632,14 @@ fn test_json_get_utf8() { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else { + let ColumnarValue::Scalar(sv) = json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + else { panic!("expected scalar") }; @@ -518,7 +657,14 @@ fn test_json_get_large_utf8() { ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))), ]; - let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else { + let ColumnarValue::Scalar(sv) = json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::LargeUtf8, + }) + .unwrap() + else { panic!("expected scalar") }; @@ -743,7 +889,7 @@ async fn test_arrow() { "+------------------+-------------------------+", "| object_foo | {str=abc} |", "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", + "| object_foo_obj | {object={\"bar\": 1}} |", "| object_foo_null | {null=} |", "| object_bar | {null=} |", "| list_foo | {null=} |", @@ -775,7 +921,7 @@ async fn 
test_long_arrow() { "+------------------+--------------------------+", "| object_foo | abc |", "| object_foo_array | [1] |", - "| object_foo_obj | {} |", + "| object_foo_obj | {\"bar\": 1} |", "| object_foo_null | |", "| object_bar | |", "| list_foo | |", diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 541d223..69b169a 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -27,7 +27,7 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result