Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add variants of functions optimized for sorted keys #77

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 91 additions & 5 deletions benches/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
use codspeed_criterion_compat::{criterion_group, criterion_main, Bencher, Criterion};

use datafusion::common::ScalarValue;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_udf};
use datafusion::{common::ScalarValue, logical_expr::ScalarFunctionArgs};
use datafusion_functions_json::udfs::{json_contains_udf, json_get_str_top_level_sorted_udf, json_get_str_udf};

/// Benchmarks the `json_contains` UDF against one scalar JSON string and the scalar
/// path keys `"a"` and `"aa"`, invoking it once per iteration and unwrapping the result.
// NOTE(review): the JSON literal below reads `"ab: "y"` (no closing quote after `ab`), so the
// document is not well-formed JSON — presumably a typo for `"ab": "y"`; confirm and fix.
fn bench_json_contains(b: &mut Bencher) {
let json_contains = json_contains_udf();
let args = &[
let args = vec![
ColumnarValue::Scalar(ScalarValue::Utf8(Some(
r#"{"a": {"aa": "x", "ab: "y"}, "b": []}"#.to_string(),
))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
];

b.iter(|| json_contains.invoke_batch(args, 1).unwrap());
b.iter(|| {
json_contains
.invoke_with_args(ScalarFunctionArgs {
args: args.clone(),
number_rows: 1,
return_type: &DataType::Boolean,
})
.unwrap()
});
}

fn bench_json_get_str(b: &mut Bencher) {
Expand All @@ -27,12 +36,89 @@ fn bench_json_get_str(b: &mut Bencher) {
ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
];

b.iter(|| json_get_str.invoke_batch(args, 1).unwrap());
b.iter(|| {
json_get_str
.invoke_with_args(ScalarFunctionArgs {
args: args.to_vec(),
number_rows: 1,
return_type: &DataType::Utf8,
})
.unwrap()
});
}

/// Builds the JSON document used by the "negative" (key not found) benchmarks.
///
/// The result is a single object with 100 keys, `"b0"` through `"b99"`, in that order.
/// Every key maps to the same large string value: the letter `a` repeated 1024 times.
fn make_json_negative_testcase() -> String {
    let filler = "a".repeat(1024);

    let mut body = String::new();
    for idx in 0..100 {
        // Entries are comma-separated with no surrounding whitespace.
        if idx > 0 {
            body.push(',');
        }
        body.push_str("\"b");
        body.push_str(&idx.to_string());
        body.push_str("\": \"");
        body.push_str(&filler);
        body.push('"');
    }

    format!("{{ {body} }}")
}

/// Benchmarks `json_get_str` on the large document from `make_json_negative_testcase`,
/// looking up the key `"a"`, which is absent from that document.
fn bench_json_get_str_negative(b: &mut Bencher) {
    let udf = json_get_str_udf();

    let document = ColumnarValue::Scalar(ScalarValue::Utf8(Some(make_json_negative_testcase())));
    // "a" is lexicographically less than "b1".
    let missing_key = ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string())));
    let args = [document, missing_key];

    b.iter(|| {
        let call = ScalarFunctionArgs {
            args: args.to_vec(),
            number_rows: 1,
            return_type: &DataType::Utf8,
        };
        udf.invoke_with_args(call).unwrap()
    });
}

/// Benchmarks the `json_get_str_top_level_sorted` UDF on a small scalar JSON document,
/// looking up the nested path `"a"` -> `"aa"`.
///
/// Each iteration clones the argument slice into a fresh `Vec`, invokes the UDF for a single
/// row with a `Utf8` return type, and unwraps the result.
fn bench_json_get_str_sorted(b: &mut Bencher) {
    let json_get_str = json_get_str_top_level_sorted_udf();
    let args = &[
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(
            // Well-formed JSON: the inner key `"ab"` previously lacked its closing quote.
            r#"{"a": {"aa": "x", "ab": "y"}, "b": []}"#.to_string(),
        ))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))),
        ColumnarValue::Scalar(ScalarValue::Utf8(Some("aa".to_string()))),
    ];

    b.iter(|| {
        json_get_str
            .invoke_with_args(ScalarFunctionArgs {
                args: args.to_vec(),
                number_rows: 1,
                return_type: &DataType::Utf8,
            })
            .unwrap()
    });
}

/// Benchmarks `json_get_str_top_level_sorted` on the large document from
/// `make_json_negative_testcase`, looking up the key `"a"`, which is absent from that
/// document and sorts before its keys.
fn bench_json_get_str_sorted_negative(b: &mut Bencher) {
    let udf = json_get_str_top_level_sorted_udf();

    let document = ColumnarValue::Scalar(ScalarValue::Utf8(Some(make_json_negative_testcase())));
    // "a" is lexicographically less than "b1".
    let missing_key = ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string())));
    let args = [document, missing_key];

    b.iter(|| {
        let call = ScalarFunctionArgs {
            args: args.to_vec(),
            number_rows: 1,
            return_type: &DataType::Utf8,
        };
        udf.invoke_with_args(call).unwrap()
    });
}

/// Registers every benchmark in this file with Criterion, in a fixed order.
fn criterion_benchmark(c: &mut Criterion) {
    // `bench_function` hands back the same `&mut Criterion`, so registrations can be chained.
    c.bench_function("json_contains", bench_json_contains)
        .bench_function("json_get_str", bench_json_get_str)
        .bench_function("json_get_str_negative", bench_json_get_str_negative)
        .bench_function("json_get_str_sorted", bench_json_get_str_sorted)
        .bench_function("json_get_str_sorted_negative", bench_json_get_str_sorted_negative);
}

criterion_group!(benches, criterion_benchmark);
Expand Down
46 changes: 45 additions & 1 deletion src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,11 @@ fn wrap_as_large_dictionary(original: &dyn AnyDictionaryArray, new_values: Array
DictionaryArray::new(keys, new_values)
}

pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {
pub fn jiter_json_find<'j>(
opt_json: Option<&'j str>,
path: &[JsonPath],
mut sorted: Sortedness,
) -> Option<(Jiter<'j>, Peek)> {
let json_str = opt_json?;
let mut jiter = Jiter::new(json_str.as_bytes());
let mut peek = jiter.peek().ok()?;
Expand All @@ -521,6 +525,11 @@ pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Opti
let mut next_key = jiter.known_object().ok()??;

while next_key != *key {
if next_key > *key && matches!(sorted, Sortedness::Recursive | Sortedness::TopLevel) {
// The current object is sorted and next_key is lexicographically greater than key
// we are looking for, so we can early stop here.
return None;
}
jiter.next_skip().ok()?;
next_key = jiter.next_key().ok()??;
}
Expand All @@ -541,6 +550,9 @@ pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Opti
return None;
}
}
if sorted == Sortedness::TopLevel {
sorted = Sortedness::Unspecified;
}
}
Some((jiter, peek))
}
Expand Down Expand Up @@ -585,3 +597,35 @@ fn mask_dictionary_keys(keys: &PrimitiveArray<Int64Type>, type_ids: &[i8]) -> Pr
}
PrimitiveArray::new(keys.values().clone(), Some(null_mask.into()))
}

/// Describes which objects inside a JSON document are known to have their keys in sorted order.
///
/// Key lookups use this to stop early: in an object known to be sorted, once the current key
/// compares lexicographically greater than the key being searched for, the search can give up.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum Sortedness {
    /// Nothing is known about key order.
    Unspecified,
    /// Only the outermost object is known to be sorted.
    /// When the outermost item is not an object this behaves like `Unspecified`.
    TopLevel,
    /// Every object is known to be sorted, including objects nested within arrays.
    Recursive,
}

impl Sortedness {
    /// Yields every variant once, in the order `Unspecified`, `TopLevel`, `Recursive`.
    pub(crate) fn iter() -> impl Iterator<Item = Self> {
        static VARIANTS: [Sortedness; 3] = [
            Sortedness::Unspecified,
            Sortedness::TopLevel,
            Sortedness::Recursive,
        ];
        VARIANTS.iter().copied()
    }

    /// Returns the suffix appended to a function's base name for this variant
    /// (empty for `Unspecified`).
    pub(crate) fn function_name_suffix(self) -> &'static str {
        match self {
            Self::Recursive => "_recursive_sorted",
            Self::TopLevel => "_top_level_sorted",
            Self::Unspecified => "",
        }
    }
}
4 changes: 2 additions & 2 deletions src/common_macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
///
/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl
macro_rules! make_udf_function {
($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr) => {
($udf_impl:ty, $expr_fn_name:ident, $($arg:ident)*, $doc:expr, $sorted:expr) => {
paste::paste! {
#[doc = $doc]
#[must_use] pub fn $expr_fn_name($($arg: datafusion::logical_expr::Expr),*) -> datafusion::logical_expr::Expr {
Expand All @@ -37,7 +37,7 @@ macro_rules! make_udf_function {
[< STATIC_ $expr_fn_name:upper >]
.get_or_init(|| {
std::sync::Arc::new(datafusion::logical_expr::ScalarUDF::new_from_impl(
<$udf_impl>::default(),
<$udf_impl>::new($sorted),
))
})
.clone()
Expand Down
41 changes: 31 additions & 10 deletions src/json_as_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,51 @@ use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath};
use crate::common::{
get_err, invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath, Sortedness,
};
use crate::common_macros::make_udf_function;

make_udf_function!(
JsonAsText,
json_as_text,
json_data path,
r#"Get any value from a JSON string by its "path", represented as a string"#
r#"Get any value from a JSON string by its "path", represented as a string"#,
Sortedness::Unspecified
);

make_udf_function!(
JsonAsText,
json_as_text_top_level_sorted,
json_data path,
r#"Get any value from a JSON string by its "path", represented as a string; assumes the JSON string's top level object's keys are sorted."#,
Sortedness::TopLevel
);

// Registers the `json_as_text_recursive_sorted` expression function, backed by `JsonAsText`
// constructed with `Sortedness::Recursive`.
make_udf_function!(
    JsonAsText,
    json_as_text_recursive_sorted,
    json_data path,
    r#"Get any value from a JSON string by its "path", represented as a string; assumes the keys of every JSON object are sorted."#,
    Sortedness::Recursive
);

/// UDF implementation type behind `json_as_text` and its `_top_level_sorted` /
/// `_recursive_sorted` variants; one instance is built per `Sortedness` value.
#[derive(Debug)]
pub(super) struct JsonAsText {
// Built in `new` as `Signature::variadic_any(Volatility::Immutable)`.
signature: Signature,
// Single alias: `json_as_text` followed by the sortedness-specific suffix.
aliases: [String; 1],
// Key-ordering assumption forwarded to the JSON path lookup on every invocation.
sorted: Sortedness,
}

impl Default for JsonAsText {
fn default() -> Self {
impl JsonAsText {
pub fn new(sorted: Sortedness) -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: ["json_as_text".to_string()],
aliases: [format!("json_as_text{}", sorted.function_name_suffix())],
sorted,
}
}
}
Expand All @@ -49,8 +70,8 @@ impl ScalarUDFImpl for JsonAsText {
return_type_check(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
invoke::<StringArray>(args, jiter_json_as_text)
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
invoke::<StringArray>(&args.args, |args, path| jiter_json_as_text(args, path, self.sorted))
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -82,8 +103,8 @@ impl InvokeResult for StringArray {
}
}

fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath]) -> Result<String, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) {
fn jiter_json_as_text(opt_json: Option<&str>, path: &[JsonPath], sorted: Sortedness) -> Result<String, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path, sorted) {
match peek {
Peek::Null => {
jiter.known_null()?;
Expand Down
41 changes: 32 additions & 9 deletions src/json_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use datafusion::arrow::array::BooleanBuilder;
use datafusion::arrow::datatypes::DataType;
use datafusion::common::arrow::array::{ArrayRef, BooleanArray};
use datafusion::common::{plan_err, Result, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};

use crate::common::{invoke, jiter_json_find, return_type_check, GetError, InvokeResult, JsonPath};
use crate::common_macros::make_udf_function;
Expand All @@ -14,20 +14,39 @@ make_udf_function!(
JsonContains,
json_contains,
json_data path,
r#"Does the key/index exist within the JSON value as the specified "path"?"#
r#"Does the key/index exist within the JSON value as the specified "path"?"#,
crate::common::Sortedness::Unspecified
);

// Registers the `json_contains_top_level_sorted` expression function, backed by `JsonContains`
// constructed with `Sortedness::TopLevel`.
make_udf_function!(
    JsonContains,
    json_contains_top_level_sorted,
    json_data path,
    r#"Does the key/index exist within the JSON value as the specified "path"? Assumes the keys of the JSON string's top-level object are sorted."#,
    crate::common::Sortedness::TopLevel
);

// Registers the `json_contains_recursive_sorted` expression function, backed by `JsonContains`
// constructed with `Sortedness::Recursive`.
make_udf_function!(
    JsonContains,
    json_contains_recursive_sorted,
    json_data path,
    r#"Does the key/index exist within the JSON value as the specified "path"? Assumes the keys of every JSON object are sorted."#,
    crate::common::Sortedness::Recursive
);

/// UDF implementation type behind `json_contains` and its `_top_level_sorted` /
/// `_recursive_sorted` variants; one instance is built per `Sortedness` value.
#[derive(Debug)]
pub(super) struct JsonContains {
// Built in `new` as `Signature::variadic_any(Volatility::Immutable)`.
signature: Signature,
// Single alias: `json_contains` followed by the sortedness-specific suffix.
aliases: [String; 1],
// Key-ordering assumption forwarded to the JSON path lookup on every invocation.
sorted: crate::common::Sortedness,
}

impl Default for JsonContains {
fn default() -> Self {
impl JsonContains {
pub fn new(sorted: crate::common::Sortedness) -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: ["json_contains".to_string()],
aliases: [format!("json_contains{}", sorted.function_name_suffix())],
sorted,
}
}
}
Expand All @@ -53,8 +72,8 @@ impl ScalarUDFImpl for JsonContains {
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
invoke::<BooleanArray>(args, jiter_json_contains)
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
invoke::<BooleanArray>(&args.args, |json, path| jiter_json_contains(json, path, self.sorted))
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -88,6 +107,10 @@ impl InvokeResult for BooleanArray {
}

#[allow(clippy::unnecessary_wraps)]
fn jiter_json_contains(json_data: Option<&str>, path: &[JsonPath]) -> Result<bool, GetError> {
Ok(jiter_json_find(json_data, path).is_some())
fn jiter_json_contains(
json_data: Option<&str>,
path: &[JsonPath],
sorted: crate::common::Sortedness,
) -> Result<bool, GetError> {
Ok(jiter_json_find(json_data, path, sorted).is_some())
}
Loading