From 6a7827a31f30058193caaeb6eb36444031167583 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 21 Feb 2025 01:00:57 -0600 Subject: [PATCH 1/8] add variants of functions optimized for sorted keys --- benches/main.rs | 47 ++++++++++++++- src/common.rs | 38 +++++++++++- src/common_macros.rs | 4 +- src/json_as_text.rs | 37 +++++++++--- src/json_contains.rs | 37 +++++++++--- src/json_get.rs | 38 +++++++++--- src/json_get_bool.rs | 37 ++++++++++-- src/json_get_float.rs | 39 +++++++++--- src/json_get_int.rs | 39 ++++++++++-- src/json_get_json.rs | 37 ++++++++++-- src/json_get_str.rs | 37 ++++++++++-- src/json_length.rs | 40 ++++++++++--- src/json_object_keys.rs | 46 ++++++++++++--- src/lib.rs | 74 ++++++++++++++++------- tests/main.rs | 128 ++++++++++++++++++++++++++++++++++++++-- tests/utils/mod.rs | 2 +- 16 files changed, 580 insertions(+), 100 deletions(-) diff --git a/benches/main.rs b/benches/main.rs index c12c003..fba8800 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -2,7 +2,7 @@ use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criter use datafusion::common::ScalarValue; use datafusion::logical_expr::ColumnarValue; -use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf}; +use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_top_level_sorted_udf, json_get_str_udf}; fn bench_json_contains(b: &mut Bencher) { let json_contains = json_contains_udf(); @@ -30,9 +30,54 @@ fn bench_json_get_str(b: &mut Bencher) { b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); } +fn make_json_negative_testcase() -> String { + // build a json with keys "b0", "b1" ... 
"b99", each with a large value ("a" repeated 1024 times) + let kvs = (0..100) + .map(|i| format!(r#""b{}": "{}""#, i, "a".repeat(1024))) + .collect::>() + .join(","); + format!(r#"{{ {} }}"#, kvs) +} + +fn bench_json_get_str_negative(b: &mut Bencher) { + let json_get_str = json_get_str_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(make_json_negative_testcase()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), // lexicographically less than every key ("b0" is the smallest) + ]; + + b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); +} + +fn bench_json_get_str_sorted(b: &mut Bencher) { + let json_get_str = json_get_str_top_level_sorted_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), + ]; + + b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); +} + +fn bench_json_get_str_sorted_negative(b: &mut Bencher) { + let json_get_str = json_get_str_top_level_sorted_udf(); + let args = &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(make_json_negative_testcase()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))), // lexicographically less than every key ("b0" is the smallest) + ]; + + b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); +} + fn criterion_benchmark(c: &mut Criterion) { c.bench_function("json_contains", bench_json_contains); c.bench_function("json_get_str", bench_json_get_str); + c.bench_function("json_get_str_negative", bench_json_get_str_negative); + c.bench_function("json_get_str_sorted", bench_json_get_str_sorted); + c.bench_function("json_get_str_sorted_negative", bench_json_get_str_sorted_negative); } criterion_group!(benches, criterion_benchmark); diff --git a/src/common.rs b/src/common.rs index 66505cb..1b9d262 100644 --- a/src/common.rs +++ b/src/common.rs @@ -511,7 +511,11 @@ fn wrap_as_large_dictionary(original: 
&dyn AnyDictionaryArray, new_values: Array DictionaryArray::new(keys, new_values) } -pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { +pub fn jiter_json_find<'j>( + opt_json: Option<&'j str>, + path: &[JsonPath], + mut sorted: Sortedness, +) -> Option<(Jiter<'j>, Peek)> { let json_str = opt_json?; let mut jiter = Jiter::new(json_str.as_bytes()); let mut peek = jiter.peek().ok()?; @@ -521,6 +525,11 @@ pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Opti let mut next_key = jiter.known_object().ok()??; while next_key != *key { + if next_key > *key && matches!(sorted, Sortedness::Recursive | Sortedness::TopLevel) { + // The current object is sorted and next_key is lexicographically greater than the key + // we are looking for, so the key cannot appear later and we can stop early. + return None; + } jiter.next_skip().ok()?; next_key = jiter.next_key().ok()??; } @@ -541,6 +550,9 @@ pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Opti return None; } } + if sorted == Sortedness::TopLevel { + sorted = Sortedness::Unspecified; + } } Some((jiter, peek)) } @@ -585,3 +597,27 @@ fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> Pr } PrimitiveArray::new(keys.values().clone(), Some(null_mask.into())) } + +/// Information about the sortedness of a JSON object. +/// This is used to optimize key lookups by stopping early when the current key is +/// lexicographically greater than the key we are looking for and the object is known to be sorted. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) enum Sortedness { + /// No guarantees about the order of the elements. + Unspecified, + /// Only the outermost object is known to be sorted. + /// If the outermost item is not an object, this is equivalent to `Unspecified`. + TopLevel, + /// All objects are known to be sorted, including objects nested within arrays. 
+ Recursive, +} + +impl Sortedness { + pub(crate) fn function_name_suffix(self) -> &'static str { + match self { + Sortedness::Unspecified => "", + Sortedness::TopLevel => "_top_level_sorted", + Sortedness::Recursive => "_recursive_sorted", + } + } +} diff --git a/src/common_macros.rs b/src/common_macros.rs index a2c6cd0..8435ffa 100644 --- a/src/common_macros.rs +++ b/src/common_macros.rs @@ -15,7 +15,7 @@ /// /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl macro_rules! make_udf_function { - ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr) => { + ($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr, $sorted:expr) => { paste::paste! { #[doc = $doc] #[must_use] pub fn $expr_fn_name($($arg: datafusion::logical_expr::Expr),*) -> datafusion::logical_expr::Expr { @@ -37,7 +37,7 @@ macro_rules! make_udf_function { [< STATIC_ $expr_fn_name:upper >] .get_or_init(|| { std::sync::Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl( - <$udf_impl>::default(), + <$udf_impl>::new($sorted), )) }) .clone() diff --git a/src/json_as_text.rs b/src/json_as_text.rs index ce73375..325980d 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -7,27 +7,48 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonAsText, json_as_text, json_data path, - r#"Get any value from a JSON string by its "path", represented as a string"# + r#"Get any value from a JSON string by its "path", represented as a string"#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonAsText, + json_as_text_top_level_sorted, + json_data path, + r#"Get any 
value from a JSON string by its "path", represented as a string; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonAsText, + json_as_text_recursive_sorted, + json_data path, + r#"Get any value from a JSON string by its "path", represented as a string; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonAsText { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonAsText { - fn default() -> Self { +impl JsonAsText { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_as_text".to_string()], + aliases: [format!("json_as_text{}", sorted.function_name_suffix())], + sorted, } } } @@ -50,7 +71,7 @@ impl ScalarUDFImpl for JsonAsText { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_as_text) + invoke::(args, |json, path| jiter_json_as_text(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -82,8 +103,8 @@ impl InvokeResult for StringArray { } } -fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { match peek { Peek::Null => { jiter.known_null()?; diff --git a/src/json_contains.rs b/src/json_contains.rs index 0eda035..fdcce48 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -14,20 +14,39 @@ make_udf_function!( JsonContains, json_contains, json_data path, - r#"Does the key/index exist within the JSON value as the specified "path"?"# + r#"Does the key/index exist within the JSON value as the specified "path"?"#, + crate::common::Sortedness::Unspecified +); + +make_udf_function!( + JsonContains, + 
json_contains_top_level_sorted, + json_data path, + r#"Does the key/index exist within the JSON value as the specified "path"; assumes the JSON string's top level object's keys are sorted?"#, + crate::common::Sortedness::TopLevel +); + +make_udf_function!( + JsonContains, + json_contains_recursive_sorted, + json_data path, + r#"Does the key/index exist within the JSON value as the specified "path"; assumes all json object's keys are sorted?"#, + crate::common::Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonContains { signature: Signature, aliases: [String; 1], + sorted: crate::common::Sortedness, } -impl Default for JsonContains { - fn default() -> Self { +impl JsonContains { + pub fn new(sorted: crate::common::Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_contains".to_string()], + aliases: [format!("json_contains{}", sorted.function_name_suffix())], + sorted, } } } @@ -54,7 +73,7 @@ impl ScalarUDFImpl for JsonContains { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - invoke::(args, jiter_json_contains) + invoke::(args, |json, path| jiter_json_contains(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -88,6 +107,10 @@ impl InvokeResult for BooleanArray { } #[allow(clippy::unnecessary_wraps)] -fn jiter_json_contains(json_data: Option<&str>, path: &[JsonPath]) -> Result { - Ok(jiter_json_find(json_data, path).is_some()) +fn jiter_json_contains( + json_data: Option<&str>, + path: &[JsonPath], + sorted: crate::common::Sortedness, +) -> Result { + Ok(jiter_json_find(json_data, path, sorted).is_some()) } diff --git a/src/json_get.rs b/src/json_get.rs index c0ba131..67bcb42 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -10,6 +10,7 @@ use datafusion::scalar::ScalarValue; use jiter::{Jiter, NumberAny, NumberInt, Peek}; use crate::common::InvokeResult; +use crate::common::Sortedness; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, 
JsonPath}; use crate::common_macros::make_udf_function; use crate::common_union::{JsonUnion, JsonUnionField}; @@ -18,22 +19,39 @@ make_udf_function!( JsonGet, json_get, json_data path, - r#"Get a value from a JSON string by its "path""# + r#"Get a value from a JSON string by its "path""#, + Sortedness::Unspecified ); -// build_typed_get!(JsonGet, "json_get", Union, Float64Array, jiter_json_get_float); +make_udf_function!( + JsonGet, + json_get_top_level_sorted, + json_data path, + r#"Get a value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGet, + json_get_recursive_sorted, + json_data path, + r#"Get a value from a JSON string by its "path"; assumes all object's keys are sorted."#, + Sortedness::Recursive +); #[derive(Debug)] pub(super) struct JsonGet { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGet { - fn default() -> Self { +impl JsonGet { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get".to_string()], + aliases: [format!("json_get{}", sorted.function_name_suffix())], + sorted, } } } @@ -56,7 +74,7 @@ impl ScalarUDFImpl for JsonGet { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_union) + invoke::(args, |json, path| jiter_json_get_union(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -93,8 +111,12 @@ impl InvokeResult for JsonUnion { } } -fn jiter_json_get_union(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_get_union( + opt_json: Option<&str>, + path: &[JsonPath], + sorted: Sortedness, +) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { build_union(&mut jiter, peek) } else { get_err!() diff --git a/src/json_get_bool.rs 
b/src/json_get_bool.rs index 2bae88b..234c6e4 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -6,27 +6,52 @@ use datafusion::common::Result as DataFusionResult; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath, Sortedness}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetBool, json_get_bool, json_data path, - r#"Get an boolean value from a JSON string by its "path""# + r#"Get a boolean value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetBool, + json_get_bool_top_level_sorted, + json_data path, + r#"Get a boolean value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetBool, + json_get_bool_recursive_sorted, + json_data path, + r#"Get a boolean value from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetBool { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } impl Default for JsonGetBool { fn default() -> Self { + Self::new(Sortedness::Unspecified) + } +} + +impl JsonGetBool { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_bool".to_string()], + aliases: [format!("json_get_bool{}", sorted.function_name_suffix())], + sorted, } } } @@ -49,7 +74,7 @@ impl ScalarUDFImpl for JsonGetBool { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_bool) + invoke::(args, |json, path| jiter_json_get_bool(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -57,8 +82,8 @@ impl ScalarUDFImpl for 
JsonGetBool { } } -fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_bool(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { Peek::True | Peek::False => Ok(jiter.known_bool(peek)?), _ => get_err!(), diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 24b00a5..8a9e4df 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -7,27 +7,48 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberAny, Peek}; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetFloat, json_get_float, json_data path, - r#"Get a float value from a JSON string by its "path""# + r#"Get a float value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetFloat, + json_get_float_top_level_sorted, + json_data path, + r#"Get a float value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetFloat, + json_get_float_recursive_sorted, + json_data path, + r#"Get a float value from a JSON string by its "path"; assumes all object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetFloat { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } -impl Default for JsonGetFloat { - fn default() -> Self { +impl JsonGetFloat { + pub fn new(sorted: Sortedness) -> Self { Self { signature: 
Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_float".to_string()], + aliases: [format!("json_get_float{}", sorted.function_name_suffix())], + sorted, } } } @@ -50,7 +71,9 @@ impl ScalarUDFImpl for JsonGetFloat { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_float) + invoke::(args, |json_data, path| { + jiter_json_get_float(json_data, path, self.sorted) + }) } fn aliases(&self) -> &[String] { @@ -83,8 +106,8 @@ impl InvokeResult for Float64Array { } } -fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { // numbers are represented by everything else in peek, hence doing it this way Peek::Null diff --git a/src/json_get_int.rs b/src/json_get_int.rs index a6a2b90..f1f16e3 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -7,27 +7,54 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberInt, Peek}; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetInt, json_get_int, json_data path, - r#"Get an integer value from a JSON string by its "path""# + r#"Get an integer value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetInt, + json_get_int_top_level_sorted, + json_data path, + r#"Get an integer value from a JSON string by its "path"; assumes the JSON string's top level object's keys are 
sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetInt, + json_get_int_recursive_sorted, + json_data path, + r#"Get an integer value from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetInt { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } impl Default for JsonGetInt { fn default() -> Self { + Self::new(Sortedness::Unspecified) + } +} + +impl JsonGetInt { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_int".to_string()], + aliases: [format!("json_get_int{}", sorted.function_name_suffix())], + sorted, } } } @@ -50,7 +77,7 @@ impl ScalarUDFImpl for JsonGetInt { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_int) + invoke::(args, |json, path| jiter_json_get_int(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -83,8 +110,8 @@ impl InvokeResult for Int64Array { } } -fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { // numbers are represented by everything else in peek, hence doing it this way Peek::Null diff --git a/src/json_get_json.rs b/src/json_get_json.rs index c5890b4..cb3b89a 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -5,27 +5,52 @@ use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, 
GetError, JsonPath, Sortedness}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetJson, json_get_json, json_data path, - r#"Get a nested raw JSON string from a JSON string by its "path""# + r#"Get a nested raw JSON string from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetJson, + json_get_json_top_level_sorted, + json_data path, + r#"Get a nested raw JSON string from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetJson, + json_get_json_recursive_sorted, + json_data path, + r#"Get a nested raw JSON string from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetJson { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } impl Default for JsonGetJson { fn default() -> Self { + Self::new(Sortedness::Unspecified) + } +} + +impl JsonGetJson { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_json".to_string()], + aliases: [format!("json_get_json{}", sorted.function_name_suffix())], + sorted, } } } @@ -48,7 +73,7 @@ impl ScalarUDFImpl for JsonGetJson { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_json) + invoke::(args, |json, path| jiter_json_get_json(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -56,8 +81,8 @@ impl ScalarUDFImpl for JsonGetJson { } } -fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_get_json(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { let start = jiter.current_index(); jiter.known_skip(peek)?; let object_slice = 
jiter.slice_to_current(start); diff --git a/src/json_get_str.rs b/src/json_get_str.rs index 633321f..a88521a 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -6,27 +6,52 @@ use datafusion::common::Result as DataFusionResult; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath, Sortedness}; use crate::common_macros::make_udf_function; make_udf_function!( JsonGetStr, json_get_str, json_data path, - r#"Get a string value from a JSON string by its "path""# + r#"Get a string value from a JSON string by its "path""#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonGetStr, + json_get_str_top_level_sorted, + json_data path, + r#"Get a string value from a JSON string by its "path"; assumes the JSON string's top level object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonGetStr, + json_get_str_recursive_sorted, + json_data path, + r#"Get a string value from a JSON string by its "path"; assumes all json object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonGetStr { signature: Signature, aliases: [String; 1], + sorted: Sortedness, } impl Default for JsonGetStr { fn default() -> Self { + Self::new(Sortedness::Unspecified) + } +} + +impl JsonGetStr { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_str".to_string()], + aliases: [format!("json_get_str{}", sorted.function_name_suffix())], + sorted, } } } @@ -49,7 +74,7 @@ impl ScalarUDFImpl for JsonGetStr { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_str) + invoke::(args, |json, path| jiter_json_get_str(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ 
-57,8 +82,8 @@ impl ScalarUDFImpl for JsonGetStr { } } -fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { +fn jiter_json_get_str(json_data: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path, sorted) { match peek { Peek::String => Ok(jiter.known_str()?.to_owned()), _ => get_err!(), diff --git a/src/json_length.rs b/src/json_length.rs index 8d0bfd3..7bbce4f 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -7,27 +7,51 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonLength, json_length, json_data path, - r"Get the length of the array or object at the given path." 
+ r#"Get the length of the array or object at the given path."#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonLength, + json_length_top_level_sorted, + json_data path, + r#"Get the length of the array or object at the given path; assumes the JSON object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonLength, + json_length_recursive_sorted, + json_data path, + r#"Get the length of the array or object at the given path; assumes all object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonLength { signature: Signature, aliases: [String; 2], + sorted: Sortedness, } -impl Default for JsonLength { - fn default() -> Self { +impl JsonLength { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_length".to_string(), "json_len".to_string()], + aliases: [ + format!("json_length{}", sorted.function_name_suffix()), + format!("json_len{}", sorted.function_name_suffix()), + ], + sorted, } } } @@ -50,7 +74,7 @@ impl ScalarUDFImpl for JsonLength { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_length) + invoke::(args, |json, path| jiter_json_length(json, path, self.sorted)) } fn aliases(&self) -> &[String] { @@ -83,8 +107,8 @@ impl InvokeResult for UInt64Array { } } -fn jiter_json_length(opt_json: Option<&str>, path: &[JsonPath]) -> Result { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_length(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { match peek { Peek::Array => { let mut peek_opt = jiter.known_array()?; diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs index 04cac0e..706f438 100644 --- a/src/json_object_keys.rs +++ b/src/json_object_keys.rs @@ -7,27 +7,51 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; 
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; +use crate::common::{ + get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness, +}; use crate::common_macros::make_udf_function; make_udf_function!( JsonObjectKeys, json_object_keys, json_data path, - r"Get the keys of a JSON object as an array." + r#"Get the keys of a JSON object as an array."#, + Sortedness::Unspecified +); + +make_udf_function!( + JsonObjectKeys, + json_keys_sorted, + json_data path, + r#"Get the keys of a JSON object as an array; assumes the JSON object's keys are sorted."#, + Sortedness::TopLevel +); + +make_udf_function!( + JsonObjectKeys, + json_keys_recursive_sorted, + json_data path, + r#"Get the keys of a JSON object as an array; assumes all object's keys are sorted."#, + Sortedness::Recursive ); #[derive(Debug)] pub(super) struct JsonObjectKeys { signature: Signature, aliases: [String; 2], + sorted: Sortedness, } -impl Default for JsonObjectKeys { - fn default() -> Self { +impl JsonObjectKeys { + pub fn new(sorted: Sortedness) -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_object_keys".to_string(), "json_keys".to_string()], + aliases: [ + format!("json_object_keys{}", sorted.function_name_suffix()), + format!("json_keys{}", sorted.function_name_suffix()), + ], + sorted, } } } @@ -54,7 +78,9 @@ impl ScalarUDFImpl for JsonObjectKeys { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_object_keys) + invoke::(args, |opt_json, path| { + jiter_json_object_keys(opt_json, path, self.sorted) + }) } fn aliases(&self) -> &[String] { @@ -106,8 +132,12 @@ fn keys_to_scalar(opt_keys: Option>) -> ScalarValue { ScalarValue::List(Arc::new(array)) } -fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) 
-> Result, GetError> { - if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { +fn jiter_json_object_keys( + opt_json: Option<&str>, + path: &[JsonPath], + sorted: Sortedness, +) -> Result, GetError> { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) { match peek { Peek::Object => { let mut opt_key = jiter.known_object()?; diff --git a/src/lib.rs b/src/lib.rs index cb0f25a..10fabbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,29 +23,43 @@ mod rewrite; pub use common_union::{JsonUnionEncoder, JsonUnionValue}; pub mod functions { - pub use crate::json_as_text::json_as_text; - pub use crate::json_contains::json_contains; - pub use crate::json_get::json_get; - pub use crate::json_get_bool::json_get_bool; - pub use crate::json_get_float::json_get_float; - pub use crate::json_get_int::json_get_int; - pub use crate::json_get_json::json_get_json; - pub use crate::json_get_str::json_get_str; - pub use crate::json_length::json_length; - pub use crate::json_object_keys::json_object_keys; + pub use crate::json_as_text::{json_as_text, json_as_text_recursive_sorted, json_as_text_top_level_sorted}; + pub use crate::json_contains::{json_contains, json_contains_recursive_sorted, json_contains_top_level_sorted}; + pub use crate::json_get::{json_get, json_get_recursive_sorted, json_get_top_level_sorted}; + pub use crate::json_get_bool::{json_get_bool, json_get_bool_recursive_sorted, json_get_bool_top_level_sorted}; + pub use crate::json_get_float::{json_get_float, json_get_float_recursive_sorted, json_get_float_top_level_sorted}; + pub use crate::json_get_int::{json_get_int, json_get_int_recursive_sorted, json_get_int_top_level_sorted}; + pub use crate::json_get_json::{json_get_json, json_get_json_recursive_sorted, json_get_json_top_level_sorted}; + pub use crate::json_get_str::{json_get_str, json_get_str_recursive_sorted, json_get_str_top_level_sorted}; + pub use crate::json_length::{json_length, json_length_recursive_sorted, 
json_length_top_level_sorted}; + pub use crate::json_object_keys::{json_keys_recursive_sorted, json_keys_sorted, json_object_keys}; } pub mod udfs { - pub use crate::json_as_text::json_as_text_udf; - pub use crate::json_contains::json_contains_udf; - pub use crate::json_get::json_get_udf; - pub use crate::json_get_bool::json_get_bool_udf; - pub use crate::json_get_float::json_get_float_udf; - pub use crate::json_get_int::json_get_int_udf; - pub use crate::json_get_json::json_get_json_udf; - pub use crate::json_get_str::json_get_str_udf; - pub use crate::json_length::json_length_udf; - pub use crate::json_object_keys::json_object_keys_udf; + pub use crate::json_as_text::{ + json_as_text_recursive_sorted_udf, json_as_text_top_level_sorted_udf, json_as_text_udf, + }; + pub use crate::json_contains::{ + json_contains_recursive_sorted_udf, json_contains_top_level_sorted_udf, json_contains_udf, + }; + pub use crate::json_get::{json_get_recursive_sorted_udf, json_get_top_level_sorted_udf, json_get_udf}; + pub use crate::json_get_bool::{ + json_get_bool_recursive_sorted_udf, json_get_bool_top_level_sorted_udf, json_get_bool_udf, + }; + pub use crate::json_get_float::{ + json_get_float_recursive_sorted_udf, json_get_float_top_level_sorted_udf, json_get_float_udf, + }; + pub use crate::json_get_int::{ + json_get_int_recursive_sorted_udf, json_get_int_top_level_sorted_udf, json_get_int_udf, + }; + pub use crate::json_get_json::{ + json_get_json_recursive_sorted_udf, json_get_json_top_level_sorted_udf, json_get_json_udf, + }; + pub use crate::json_get_str::{ + json_get_str_recursive_sorted_udf, json_get_str_top_level_sorted_udf, json_get_str_udf, + }; + pub use crate::json_length::{json_length_recursive_sorted_udf, json_length_top_level_sorted_udf, json_length_udf}; + pub use crate::json_object_keys::{json_keys_recursive_sorted_udf, json_keys_sorted_udf, json_object_keys_udf}; } /// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided 
[`FunctionRegistry`]. @@ -60,15 +74,35 @@ pub mod udfs { pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ json_get::json_get_udf(), + json_get::json_get_top_level_sorted_udf(), + json_get::json_get_recursive_sorted_udf(), json_get_bool::json_get_bool_udf(), + json_get_bool::json_get_bool_top_level_sorted_udf(), + json_get_bool::json_get_bool_recursive_sorted_udf(), json_get_float::json_get_float_udf(), + json_get_float::json_get_float_top_level_sorted_udf(), + json_get_float::json_get_float_recursive_sorted_udf(), json_get_int::json_get_int_udf(), + json_get_int::json_get_int_top_level_sorted_udf(), + json_get_int::json_get_int_recursive_sorted_udf(), json_get_json::json_get_json_udf(), + json_get_json::json_get_json_top_level_sorted_udf(), + json_get_json::json_get_json_recursive_sorted_udf(), json_as_text::json_as_text_udf(), + json_as_text::json_as_text_top_level_sorted_udf(), + json_as_text::json_as_text_recursive_sorted_udf(), json_get_str::json_get_str_udf(), + json_get_str::json_get_str_top_level_sorted_udf(), + json_get_str::json_get_str_recursive_sorted_udf(), json_contains::json_contains_udf(), + json_contains::json_contains_top_level_sorted_udf(), + json_contains::json_contains_recursive_sorted_udf(), json_length::json_length_udf(), + json_length::json_length_top_level_sorted_udf(), + json_length::json_length_recursive_sorted_udf(), json_object_keys::json_object_keys_udf(), + json_object_keys::json_keys_sorted_udf(), + json_object_keys::json_keys_recursive_sorted_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/tests/main.rs b/tests/main.rs index 0e65338..5bc542d 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -72,7 +72,7 @@ async fn test_json_get_union() { "+------------------+--------------------------------------+", "| object_foo | {str=abc} |", "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", + "| 
object_foo_obj | {object={\"bar\": 1}} |", "| object_foo_null | {null=} |", "| object_bar | {null=} |", "| list_foo | {null=} |", @@ -147,6 +147,126 @@ async fn test_json_get_str() { assert_batches_eq!(expected, &batches); } +#[tokio::test] +async fn test_json_get_str_top_level_sorted() { + let batches = run_query("select name, json_get_str_top_level_sorted(json_data, 'aaa') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_top_level_sorted(test.json_data,Utf8(\"aaa\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_top_level_sorted(json_data, 'foo') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_top_level_sorted(test.json_data,Utf8(\"foo\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | abc |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_top_level_sorted(json_data, 'zzz') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_top_level_sorted(test.json_data,Utf8(\"zzz\")) |", + 
"+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_get_str_recursive_level_sorted() { + let batches = run_query("select name, json_get_str_recursive_sorted(json_data, 'aaa') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_recursive_sorted(test.json_data,Utf8(\"aaa\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_recursive_sorted(json_data, 'foo') from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_recursive_sorted(test.json_data,Utf8(\"foo\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | abc |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + let batches = run_query("select name, json_get_str_recursive_sorted(json_data, 'zzz') from test") + .await + .unwrap(); + + let expected = [ 
+ "+------------------+-----------------------------------------------------------+", + "| name | json_get_str_recursive_sorted(test.json_data,Utf8(\"zzz\")) |", + "+------------------+-----------------------------------------------------------+", + "| object_foo | |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_json_get_str_equals() { let sql = "select name, json_get_str(json_data, 'foo')='abc' from test"; @@ -324,7 +444,7 @@ async fn test_json_get_json() { "+------------------+-------------------------------------------+", "| object_foo | \"abc\" |", "| object_foo_array | [1] |", - "| object_foo_obj | {} |", + "| object_foo_obj | {\"bar\": 1} |", "| object_foo_null | null |", "| object_bar | |", "| list_foo | |", @@ -743,7 +863,7 @@ async fn test_arrow() { "+------------------+-------------------------+", "| object_foo | {str=abc} |", "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", + "| object_foo_obj | {object={\"bar\": 1}} |", "| object_foo_null | {null=} |", "| object_bar | {null=} |", "| list_foo | {null=} |", @@ -775,7 +895,7 @@ async fn test_long_arrow() { "+------------------+--------------------------+", "| object_foo | abc |", "| object_foo_array | [1] |", - "| object_foo_obj | {} |", + "| object_foo_obj | {\"bar\": 1} |", "| object_foo_null | |", "| object_bar | |", "| list_foo | |", diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 541d223..69b169a 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -27,7 +27,7 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result Date: Fri, 21 Feb 2025 01:10:34 -0600 Subject: [PATCH 2/8] lint --- benches/main.rs | 2 +- src/json_get_bool.rs | 6 ------ src/json_get_int.rs | 6 ------ 
src/json_get_json.rs | 6 ------ src/json_get_str.rs | 6 ------ 5 files changed, 1 insertion(+), 25 deletions(-) diff --git a/benches/main.rs b/benches/main.rs index fba8800..a475e48 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -36,7 +36,7 @@ fn make_json_negative_testcase() -> String { .map(|i| format!(r#""b{}": "{}""#, i, "a".repeat(1024))) .collect::>() .join(","); - format!(r#"{{ {} }}"#, kvs) + format!(r"{{ {kvs} }}") } fn bench_json_get_str_negative(b: &mut Bencher) { diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 234c6e4..7987b18 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -40,12 +40,6 @@ pub(super) struct JsonGetBool { sorted: Sortedness, } -impl Default for JsonGetBool { - fn default() -> Self { - Self::new(Sortedness::Unspecified) - } -} - impl JsonGetBool { pub fn new(sorted: Sortedness) -> Self { Self { diff --git a/src/json_get_int.rs b/src/json_get_int.rs index f1f16e3..c301370 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -43,12 +43,6 @@ pub(super) struct JsonGetInt { sorted: Sortedness, } -impl Default for JsonGetInt { - fn default() -> Self { - Self::new(Sortedness::Unspecified) - } -} - impl JsonGetInt { pub fn new(sorted: Sortedness) -> Self { Self { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index cb3b89a..f76fef4 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -39,12 +39,6 @@ pub(super) struct JsonGetJson { sorted: Sortedness, } -impl Default for JsonGetJson { - fn default() -> Self { - Self::new(Sortedness::Unspecified) - } -} - impl JsonGetJson { pub fn new(sorted: Sortedness) -> Self { Self { diff --git a/src/json_get_str.rs b/src/json_get_str.rs index a88521a..4552022 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -40,12 +40,6 @@ pub(super) struct JsonGetStr { sorted: Sortedness, } -impl Default for JsonGetStr { - fn default() -> Self { - Self::new(Sortedness::Unspecified) - } -} - impl JsonGetStr { pub fn new(sorted: Sortedness) -> Self 
{ Self { From a85e124c088ef6f99573b6b39fef1d6c546cd8d0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 21 Feb 2025 01:18:56 -0600 Subject: [PATCH 3/8] lint --- src/json_length.rs | 6 +++--- src/json_object_keys.rs | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/json_length.rs b/src/json_length.rs index 7bbce4f..884cb3d 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -16,7 +16,7 @@ make_udf_function!( JsonLength, json_length, json_data path, - r#"Get the length of the array or object at the given path."#, + r"Get the length of the array or object at the given path.", Sortedness::Unspecified ); @@ -24,7 +24,7 @@ make_udf_function!( JsonLength, json_length_top_level_sorted, json_data path, - r#"Get the length of the array or object at the given path; assumes the JSON object's keys are sorted."#, + r"Get the length of the array or object at the given path; assumes the JSON object's keys are sorted.", Sortedness::TopLevel ); @@ -32,7 +32,7 @@ make_udf_function!( JsonLength, json_length_recursive_sorted, json_data path, - r#"Get the length of the array or object at the given path; assumes all object's keys are sorted."#, + r"Get the length of the array or object at the given path; assumes all object's keys are sorted.", Sortedness::Recursive ); diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs index 706f438..cc4e7b4 100644 --- a/src/json_object_keys.rs +++ b/src/json_object_keys.rs @@ -16,7 +16,7 @@ make_udf_function!( JsonObjectKeys, json_object_keys, json_data path, - r#"Get the keys of a JSON object as an array."#, + r"Get the keys of a JSON object as an array.", Sortedness::Unspecified ); @@ -24,7 +24,7 @@ make_udf_function!( JsonObjectKeys, json_keys_sorted, json_data path, - r#"Get the keys of a JSON object as an array; assumes the JSON object's keys are sorted."#, + r"Get the keys of a JSON object as an array; assumes the JSON object's keys are 
sorted.", Sortedness::TopLevel ); @@ -32,7 +32,7 @@ make_udf_function!( JsonObjectKeys, json_keys_recursive_sorted, json_data path, - r#"Get the keys of a JSON object as an array; assumes all object's keys are sorted."#, + r"Get the keys of a JSON object as an array; assumes all object's keys are sorted.", Sortedness::Recursive ); From b1215fbf2a4855f3f1c92473550b8b9d3d1bb2e1 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:31:56 -0600 Subject: [PATCH 4/8] fix optimizations --- src/common.rs | 8 ++++++++ src/rewrite.rs | 36 +++++++++++++++++++++++++----------- tests/main.rs | 12 ++++++++++++ 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/common.rs b/src/common.rs index 1b9d262..5bfd079 100644 --- a/src/common.rs +++ b/src/common.rs @@ -612,6 +612,14 @@ pub(crate) enum Sortedness { Recursive, } +impl Sortedness { + pub(crate) fn iter() -> impl Iterator { + [Sortedness::Unspecified, Sortedness::TopLevel, Sortedness::Recursive] + .iter() + .copied() + } +} + impl Sortedness { pub(crate) fn function_name_suffix(self) -> &'static str { match self { diff --git a/src/rewrite.rs b/src/rewrite.rs index 4161154..ce05d51 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::sync::LazyLock; use datafusion::arrow::datatypes::DataType; use datafusion::common::config::ConfigOptions; @@ -11,8 +12,12 @@ use datafusion::logical_expr::expr_rewriter::FunctionRewrite; use datafusion::logical_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; use datafusion::logical_expr::sqlparser::ast::BinaryOperator; use datafusion::logical_expr::ScalarUDF; +use datafusion::logical_expr::ScalarUDFImpl; use datafusion::scalar::ScalarValue; +use crate::common::Sortedness; +use crate::json_get::JsonGet; + #[derive(Debug)] pub(crate) struct JsonFunctionRewriter; @@ -31,11 +36,14 @@ impl FunctionRewrite for JsonFunctionRewriter { } } +static 
JSON_GET_FUNC_NAMES: LazyLock> = + LazyLock::new(|| Sortedness::iter().map(|s| JsonGet::new(s).name().to_string()).collect()); + /// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of /// extracting the right value type from JSON without the need to materialize the JSON union. fn optimise_json_get_cast(cast: &Cast) -> Option> { let scalar_func = extract_scalar_function(&cast.expr)?; - if scalar_func.func.name() != "json_get" { + if !JSON_GET_FUNC_NAMES.contains(&scalar_func.func.name().to_owned()) { return None; } let func = match &cast.data_type { @@ -53,18 +61,24 @@ fn optimise_json_get_cast(cast: &Cast) -> Option> { }))) } +static JSON_FUNCTION_NAMES: LazyLock> = LazyLock::new(|| { + Sortedness::iter() + .flat_map(|s| { + [ + crate::json_get::JsonGet::new(s).name().to_string(), + crate::json_get_bool::JsonGetBool::new(s).name().to_string(), + crate::json_get_float::JsonGetFloat::new(s).name().to_string(), + crate::json_get_int::JsonGetInt::new(s).name().to_string(), + crate::json_get_str::JsonGetStr::new(s).name().to_string(), + crate::json_as_text::JsonAsText::new(s).name().to_string(), + ] + }) + .collect() +}); + // Replace nested JSON functions e.g. 
`json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')` fn unnest_json_calls(func: &ScalarFunction) -> Option> { - if !matches!( - func.func.name(), - "json_get" - | "json_get_bool" - | "json_get_float" - | "json_get_int" - | "json_get_json" - | "json_get_str" - | "json_as_text" - ) { + if !JSON_FUNCTION_NAMES.contains(&func.func.name().to_owned()) { return None; } let mut outer_args_iter = func.args.iter(); diff --git a/tests/main.rs b/tests/main.rs index 5bc542d..4268016 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -346,6 +346,10 @@ async fn test_json_get_cast_int() { let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); + let sql = r#"select json_get_top_level_sorted('{"foo": 42}', 'foo')::int"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); + // floats not allowed let sql = r#"select json_get('{"foo": 4.2}', 'foo')::int"#; let batches = run_query(sql).await.unwrap(); @@ -400,6 +404,10 @@ async fn test_json_get_cast_float() { let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::float"#; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); + + let sql = r#"select json_get_top_level_sorted('{"foo": 4.2e2}', 'foo')::float"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); } #[tokio::test] @@ -407,6 +415,10 @@ async fn test_json_get_cast_numeric() { let sql = r#"select json_get('{"foo": 4.2e2}', 'foo')::numeric"#; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Float64, "420.0".to_string())); + + let sql = r#"select json_get_top_level_sorted('{"foo": 4.2e2}', 'foo')::numeric"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Float64, 
"420.0".to_string())); } #[tokio::test] From 58f56cdff88833f75be09866a781882adf16c427 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 27 Feb 2025 10:38:16 -0600 Subject: [PATCH 5/8] fix bug in unnest --- src/rewrite.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rewrite.rs b/src/rewrite.rs index ce05d51..94c2e08 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -91,6 +91,10 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option> { return None; } + if func.name() != inner_func.func.name() { + return None; + } + let mut args = inner_func.args.clone(); args.extend(outer_args_iter.cloned()); // See #23, unnest only when all lookup arguments are literals From 1dc59bd6a53562c818f5e74340555ac838a0e1f4 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 27 Feb 2025 13:48:40 -0600 Subject: [PATCH 6/8] revert --- src/rewrite.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/rewrite.rs b/src/rewrite.rs index 94c2e08..ce05d51 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -91,10 +91,6 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option> { return None; } - if func.name() != inner_func.func.name() { - return None; - } - let mut args = inner_func.args.clone(); args.extend(outer_args_iter.cloned()); // See #23, unnest only when all lookup arguments are literals From aff3e71318f1db8ab0e8851d1cc1e26693f6ff9a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:01:52 -0500 Subject: [PATCH 7/8] Prepare for DataFusion 47 --- Cargo.toml | 8 ++++++++ benches/main.rs | 25 +++++++++++++++++++++---- src/json_as_text.rs | 6 +++--- src/json_contains.rs | 6 +++--- src/json_get.rs | 6 +++--- src/json_get_bool.rs | 6 +++--- src/json_get_float.rs | 6 +++--- src/json_get_int.rs | 6 +++--- src/json_get_json.rs | 6 +++--- src/json_get_str.rs | 6 +++--- src/json_length.rs | 
6 +++--- src/json_object_keys.rs | 6 +++--- tests/main.rs | 20 +++++++++++++++++--- 13 files changed, 76 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4dbb883..de85981 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,11 @@ pedantic = { level = "deny", priority = -1 } [[bench]] name = "main" harness = false + +[patch.crates-io] +datafusion = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-functions-aggregate = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-macros = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } diff --git a/benches/main.rs b/benches/main.rs index c12c003..5c3eac5 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -1,12 +1,13 @@ use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion}; -use datafusion::common::ScalarValue; +use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::ColumnarValue; +use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs}; use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf}; fn bench_json_contains(b: &mut Bencher) { let json_contains = json_contains_udf(); - let args = &[ + let args = vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some( r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), ))), @@ -14,7 +15,15 @@ fn bench_json_contains(b: &mut Bencher) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - b.iter(|| json_contains.invoke_batch(args, 1).unwrap()); + b.iter(|| { + json_contains + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: 1, + return_type: &DataType::Boolean, + }) + .unwrap() 
+ }); } fn bench_json_get_str(b: &mut Bencher) { @@ -27,7 +36,15 @@ fn bench_json_get_str(b: &mut Bencher) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); + b.iter(|| { + json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + }); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/json_as_text.rs b/src/json_as_text.rs index ce73375..4c3fdba 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonAsText { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_as_text) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_as_text) } fn aliases(&self) -> &[String] { diff --git a/src/json_contains.rs b/src/json_contains.rs index 0eda035..a1ab056 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -5,7 +5,7 @@ use datafusion::arrow::array::BooleanBuilder; use datafusion::arrow::datatypes::DataType; use datafusion::common::arrow::array::{ArrayRef, BooleanArray}; use datafusion::common::{plan_err, Result, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, 
ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use crate::common::{invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; use crate::common_macros::make_udf_function; @@ -53,8 +53,8 @@ impl ScalarUDFImpl for JsonContains { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - invoke::(args, jiter_json_contains) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + invoke::(&args.args, jiter_json_contains) } fn aliases(&self) -> &[String] { diff --git a/src/json_get.rs b/src/json_get.rs index c0ba131..097bae2 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -5,7 +5,7 @@ use datafusion::arrow::array::ArrayRef; use datafusion::arrow::array::UnionArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion::scalar::ScalarValue; use jiter::{Jiter, NumberAny, NumberInt, Peek}; @@ -55,8 +55,8 @@ impl ScalarUDFImpl for JsonGet { return_type_check(arg_types, self.name(), JsonUnion::data_type()) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_union) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_union) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 2bae88b..4cb4560 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -3,7 +3,7 @@ use std::any::Any; use datafusion::arrow::array::BooleanArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; 
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; @@ -48,8 +48,8 @@ impl ScalarUDFImpl for JsonGetBool { return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_bool) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_bool) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 24b00a5..34e0247 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Float64Array, Float64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberAny, Peek}; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonGetFloat { return_type_check(arg_types, self.name(), DataType::Float64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_float) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_float) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_int.rs b/src/json_get_int.rs index a6a2b90..26d24ec 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Int64Array, Int64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, 
Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberInt, Peek}; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonGetInt { return_type_check(arg_types, self.name(), DataType::Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_int) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_int) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index c5890b4..a8c6477 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -3,7 +3,7 @@ use std::any::Any; use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; use crate::common_macros::make_udf_function; @@ -47,8 +47,8 @@ impl ScalarUDFImpl for JsonGetJson { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_json) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_json) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_str.rs b/src/json_get_str.rs index 633321f..e8ee200 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -3,7 +3,7 @@ use std::any::Any; use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use 
datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; @@ -48,8 +48,8 @@ impl ScalarUDFImpl for JsonGetStr { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_str) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_str) } fn aliases(&self) -> &[String] { diff --git a/src/json_length.rs b/src/json_length.rs index 8d0bfd3..f478854 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, UInt64Array, UInt64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonLength { return_type_check(arg_types, self.name(), DataType::UInt64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_length) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_length) } fn aliases(&self) -> &[String] { diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs index 04cac0e..a07cec0 100644 --- a/src/json_object_keys.rs +++ b/src/json_object_keys.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder}; use 
datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -53,8 +53,8 @@ impl ScalarUDFImpl for JsonObjectKeys { ) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_object_keys) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_object_keys) } fn aliases(&self) -> &[String] { diff --git a/tests/main.rs b/tests/main.rs index 0e65338..f591385 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema}; use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType}; use datafusion::assert_batches_eq; use datafusion::common::ScalarValue; -use datafusion::logical_expr::ColumnarValue; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion::prelude::SessionContext; use datafusion_functions_json::udfs::json_get_str_udf; use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params}; @@ -500,7 +500,14 @@ fn test_json_get_utf8() { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else { + let ColumnarValue::Scalar(sv) = json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + else { panic!("expected scalar") }; @@ -518,7 +525,14 @@ fn test_json_get_large_utf8() { ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))), ]; - let 
ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else { + let ColumnarValue::Scalar(sv) = json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::LargeUtf8, + }) + .unwrap() + else { panic!("expected scalar") }; From a739fd706053ef134e25f3460cdaee59dd7f46a3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:33:33 -0500 Subject: [PATCH 8/8] update --- Cargo.toml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index de85981..4dbb883 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,11 +34,3 @@ pedantic = { level = "deny", priority = -1 } [[bench]] name = "main" harness = false - -[patch.crates-io] -datafusion = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } -datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } -datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } -datafusion-functions-aggregate = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } -datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } -datafusion-macros = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" }