Skip to content

Commit a645a62

Browse files
authored
support arrow functions with ExprPlanner (#26)
1 parent 40a8090 commit a645a62

File tree

5 files changed

+371
-51
lines changed

5 files changed

+371
-51
lines changed

Cargo.toml

+13-6
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,29 @@ repository = "https://github.com/datafusion-contrib/datafusion-functions-json/"
1111
rust-version = "1.76.0"
1212

1313
[dependencies]
14-
arrow = "52"
15-
arrow-schema = "52"
16-
datafusion-common = "39"
17-
datafusion-expr = "39"
14+
arrow = "52.1.0"
15+
arrow-schema = "52.1.0"
16+
datafusion-common = "40"
17+
datafusion-expr = "40"
18+
datafusion-execution = "40"
1819
jiter = "0.5"
1920
paste = "1"
2021
log = "0.4"
21-
datafusion-execution = "39"
2222

2323
[dev-dependencies]
2424
codspeed-criterion-compat = "2.3"
2525
criterion = "0.5.1"
26-
datafusion = "39"
26+
datafusion = "40"
2727
clap = "4"
2828
tokio = { version = "1.37", features = ["full"] }
2929

30+
[patch.crates-io]
31+
# TODO: remove this once datafusion 40.0 is released
32+
datafusion = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
33+
datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
34+
datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
35+
datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
36+
3037
[lints.clippy]
3138
dbg_macro = "deny"
3239
print_stdout = "deny"

src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
6767
Ok(()) as Result<()>
6868
})?;
6969
registry.register_function_rewrite(Arc::new(rewrite::JsonFunctionRewriter))?;
70+
registry.register_expr_planner(Arc::new(rewrite::JsonExprPlanner))?;
7071

7172
Ok(())
7273
}

src/rewrite.rs

+68-35
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ use datafusion_common::config::ConfigOptions;
33
use datafusion_common::tree_node::Transformed;
44
use datafusion_common::DFSchema;
55
use datafusion_common::Result;
6-
use datafusion_expr::expr::ScalarFunction;
6+
use datafusion_expr::expr::{Alias, Cast, Expr, ScalarFunction};
77
use datafusion_expr::expr_rewriter::FunctionRewrite;
8-
use datafusion_expr::Expr;
8+
use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
9+
use datafusion_expr::sqlparser::ast::BinaryOperator;
910

1011
pub(crate) struct JsonFunctionRewriter;
1112

@@ -15,25 +16,37 @@ impl FunctionRewrite for JsonFunctionRewriter {
1516
}
1617

1718
fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result<Transformed<Expr>> {
18-
if let Expr::Cast(cast) = &expr {
19-
if let Expr::ScalarFunction(func) = &*cast.expr {
20-
if func.func.name() == "json_get" {
21-
if let Some(t) = switch_json_get(&cast.data_type, &func.args) {
22-
return Ok(t);
23-
}
24-
}
25-
}
26-
} else if let Expr::ScalarFunction(func) = &expr {
27-
if let Some(new_func) = unnest_json_calls(func) {
28-
return Ok(Transformed::yes(Expr::ScalarFunction(new_func)));
29-
}
30-
}
31-
Ok(Transformed::no(expr))
19+
let transform = match &expr {
20+
Expr::Cast(cast) => optimise_json_get_cast(cast),
21+
Expr::ScalarFunction(func) => unnest_json_calls(func),
22+
_ => None,
23+
};
24+
Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
3225
}
3326
}
3427

28+
/// This replaces `json_get(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of
29+
/// extracting the right value type from JSON without the need to materialize the JSON union.
30+
fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
31+
let scalar_func = extract_scalar_function(&cast.expr)?;
32+
if scalar_func.func.name() != "json_get" {
33+
return None;
34+
}
35+
let func = match &cast.data_type {
36+
DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
37+
DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(),
38+
DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
39+
DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
40+
_ => return None,
41+
};
42+
Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
43+
func,
44+
args: scalar_func.args.clone(),
45+
})))
46+
}
47+
3548
// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
36-
fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
49+
fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
3750
if !matches!(
3851
func.func.name(),
3952
"json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str"
@@ -42,9 +55,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
4255
}
4356
let mut outer_args_iter = func.args.iter();
4457
let first_arg = outer_args_iter.next()?;
45-
let Expr::ScalarFunction(inner_func) = first_arg else {
46-
return None;
47-
};
58+
let inner_func = extract_scalar_function(first_arg)?;
4859
if inner_func.func.name() != "json_get" {
4960
return None;
5061
}
@@ -53,26 +64,48 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
5364
args.extend(outer_args_iter.cloned());
5465
// See #23, unnest only when all lookup arguments are literals
5566
if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) {
56-
Some(ScalarFunction {
67+
Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
5768
func: func.func.clone(),
5869
args,
59-
})
70+
})))
6071
} else {
6172
None
6273
}
6374
}
6475

65-
fn switch_json_get(cast_data_type: &DataType, args: &[Expr]) -> Option<Transformed<Expr>> {
66-
let func = match cast_data_type {
67-
DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
68-
DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(),
69-
DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
70-
DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
71-
_ => return None,
72-
};
73-
let f = ScalarFunction {
74-
func,
75-
args: args.to_vec(),
76-
};
77-
Some(Transformed::yes(Expr::ScalarFunction(f)))
76+
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
77+
match expr {
78+
Expr::ScalarFunction(func) => Some(func),
79+
Expr::Alias(alias) => extract_scalar_function(&*alias.expr),
80+
_ => None,
81+
}
82+
}
83+
84+
/// Custom SQL expression planner that replaces the Postgres JSON operators `->`, `->>` and `?` with this crate's UDFs (`json_get`, `json_get_str` and `json_contains` respectively).
85+
#[derive(Debug, Default)]
86+
pub struct JsonExprPlanner;
87+
88+
impl ExprPlanner for JsonExprPlanner {
89+
fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result<PlannerResult<RawBinaryExpr>> {
90+
let (func, op_display) = match &expr.op {
91+
BinaryOperator::Arrow => (crate::json_get::json_get_udf(), "->"),
92+
BinaryOperator::LongArrow => (crate::json_get_str::json_get_str_udf(), "->>"),
93+
BinaryOperator::Question => (crate::json_contains::json_contains_udf(), "?"),
94+
_ => return Ok(PlannerResult::Original(expr)),
95+
};
96+
let alias_name = match &expr.left {
97+
Expr::Alias(alias) => format!("{} {} {}", alias.name, op_display, expr.right),
98+
left_expr => format!("{} {} {}", left_expr, op_display, expr.right),
99+
};
100+
101+
// we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)`
102+
Ok(PlannerResult::Planned(Expr::Alias(Alias::new(
103+
Expr::ScalarFunction(ScalarFunction {
104+
func,
105+
args: vec![expr.left, expr.right],
106+
}),
107+
None::<&str>,
108+
alias_name,
109+
))))
110+
}
78111
}

0 commit comments

Comments
 (0)