From aff3e71318f1db8ab0e8851d1cc1e26693f6ff9a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Mar 2025 13:01:52 -0500 Subject: [PATCH] Prepare for DataFusion 47 --- Cargo.toml | 8 ++++++++ benches/main.rs | 25 +++++++++++++++++++++---- src/json_as_text.rs | 6 +++--- src/json_contains.rs | 6 +++--- src/json_get.rs | 6 +++--- src/json_get_bool.rs | 6 +++--- src/json_get_float.rs | 6 +++--- src/json_get_int.rs | 6 +++--- src/json_get_json.rs | 6 +++--- src/json_get_str.rs | 6 +++--- src/json_length.rs | 6 +++--- src/json_object_keys.rs | 6 +++--- tests/main.rs | 20 +++++++++++++++++--- 13 files changed, 76 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4dbb883..de85981 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,11 @@ pedantic = { level = "deny", priority = -1 } [[bench]] name = "main" harness = false + +[patch.crates-io] +datafusion = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-functions-aggregate = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } +datafusion-macros = { git = "https://github.com/apache/datafusion.git", rev = "722ccb9" } diff --git a/benches/main.rs b/benches/main.rs index c12c003..5c3eac5 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -1,12 +1,13 @@ use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion}; -use datafusion::common::ScalarValue; +use datafusion::arrow::datatypes::DataType; use datafusion::logical_expr::ColumnarValue; +use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs}; use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf}; fn bench_json_contains(b: &mut Bencher) { let json_contains = json_contains_udf(); - let args = &[ + let args = vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some( r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(), ))), @@ -14,7 +15,15 @@ fn bench_json_contains(b: &mut Bencher) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - b.iter(|| json_contains.invoke_batch(args, 1).unwrap()); + b.iter(|| { + json_contains + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + number_rows: 1, + return_type: &DataType::Boolean, + }) + .unwrap() + }); } fn bench_json_get_str(b: &mut Bencher) { @@ -27,7 +36,15 @@ fn bench_json_get_str(b: &mut Bencher) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - b.iter(|| json_get_str.invoke_batch(args, 1).unwrap()); + b.iter(|| { + json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + }); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/json_as_text.rs b/src/json_as_text.rs index ce73375..4c3fdba 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonAsText { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_as_text) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_as_text) } fn aliases(&self) -> &[String] { diff --git a/src/json_contains.rs b/src/json_contains.rs index 0eda035..a1ab056 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -5,7 +5,7 @@ use datafusion::arrow::array::BooleanBuilder; use datafusion::arrow::datatypes::DataType; use datafusion::common::arrow::array::{ArrayRef, BooleanArray}; use datafusion::common::{plan_err, Result, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use crate::common::{invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; use crate::common_macros::make_udf_function; @@ -53,8 +53,8 @@ impl ScalarUDFImpl for JsonContains { } } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - invoke::(args, jiter_json_contains) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + invoke::(&args.args, jiter_json_contains) } fn aliases(&self) -> &[String] { diff --git a/src/json_get.rs b/src/json_get.rs index c0ba131..097bae2 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -5,7 +5,7 @@ use datafusion::arrow::array::ArrayRef; use datafusion::arrow::array::UnionArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion::scalar::ScalarValue; use jiter::{Jiter, NumberAny, NumberInt, Peek}; @@ -55,8 +55,8 @@ impl ScalarUDFImpl for JsonGet { return_type_check(arg_types, self.name(), JsonUnion::data_type()) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_union) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_union) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 2bae88b..4cb4560 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -3,7 +3,7 @@ use std::any::Any; use datafusion::arrow::array::BooleanArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; @@ -48,8 +48,8 @@ impl ScalarUDFImpl for JsonGetBool { return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_bool) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_bool) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 24b00a5..34e0247 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Float64Array, Float64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberAny, Peek}; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonGetFloat { return_type_check(arg_types, self.name(), DataType::Float64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_float) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_float) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_int.rs b/src/json_get_int.rs index a6a2b90..26d24ec 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, Int64Array, Int64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberInt, Peek}; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonGetInt { return_type_check(arg_types, self.name(), DataType::Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_int) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_int) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index c5890b4..a8c6477 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -3,7 +3,7 @@ use std::any::Any; use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; use crate::common_macros::make_udf_function; @@ -47,8 +47,8 @@ impl ScalarUDFImpl for JsonGetJson { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_json) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_json) } fn aliases(&self) -> &[String] { diff --git a/src/json_get_str.rs b/src/json_get_str.rs index 633321f..e8ee200 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -3,7 +3,7 @@ use std::any::Any; use datafusion::arrow::array::StringArray; use datafusion::arrow::datatypes::DataType; use datafusion::common::Result as DataFusionResult; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; @@ -48,8 +48,8 @@ impl ScalarUDFImpl for JsonGetStr { return_type_check(arg_types, self.name(), DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_get_str) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_get_str) } fn aliases(&self) -> &[String] { diff --git a/src/json_length.rs b/src/json_length.rs index 8d0bfd3..f478854 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, UInt64Array, UInt64Builder}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -49,8 +49,8 @@ impl ScalarUDFImpl for JsonLength { return_type_check(arg_types, self.name(), DataType::UInt64) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_length) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_length) } fn aliases(&self) -> &[String] { diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs index 04cac0e..a07cec0 100644 --- a/src/json_object_keys.rs +++ b/src/json_object_keys.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use datafusion::arrow::array::{ArrayRef, ListBuilder, StringBuilder}; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::common::{Result as DataFusionResult, ScalarValue}; -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath}; @@ -53,8 +53,8 @@ impl ScalarUDFImpl for JsonObjectKeys { ) } - fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::(args, jiter_json_object_keys) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + invoke::(&args.args, jiter_json_object_keys) } fn aliases(&self) -> &[String] { diff --git a/tests/main.rs b/tests/main.rs index 0e65338..f591385 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{Field, Int64Type, Int8Type, Schema}; use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType}; use datafusion::assert_batches_eq; use datafusion::common::ScalarValue; -use datafusion::logical_expr::ColumnarValue; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion::prelude::SessionContext; use datafusion_functions_json::udfs::json_get_str_udf; use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params}; @@ -500,7 +500,14 @@ fn test_json_get_utf8() { ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))), ]; - let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else { + let ColumnarValue::Scalar(sv) = json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::Utf8, + }) + .unwrap() + else { panic!("expected scalar") }; @@ -518,7 +525,14 @@ fn test_json_get_large_utf8() { ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))), ]; - let ColumnarValue::Scalar(sv) = json_get_str.invoke_batch(args, 1).unwrap() else { + let ColumnarValue::Scalar(sv) = json_get_str + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + number_rows: 1, + return_type: &DataType::LargeUtf8, + }) + .unwrap() + else { panic!("expected scalar") };