@@ -3,9 +3,10 @@ use datafusion_common::config::ConfigOptions;
3
3
use datafusion_common:: tree_node:: Transformed ;
4
4
use datafusion_common:: DFSchema ;
5
5
use datafusion_common:: Result ;
6
- use datafusion_expr:: expr:: ScalarFunction ;
6
+ use datafusion_expr:: expr:: { Alias , Cast , Expr , ScalarFunction } ;
7
7
use datafusion_expr:: expr_rewriter:: FunctionRewrite ;
8
- use datafusion_expr:: Expr ;
8
+ use datafusion_expr:: planner:: { ExprPlanner , PlannerResult , RawBinaryExpr } ;
9
+ use datafusion_expr:: sqlparser:: ast:: BinaryOperator ;
9
10
10
11
/// Marker type carrying the `FunctionRewrite` implementation that rewrites JSON function calls
/// (typed casts of `json_get` and nested JSON lookups) into their more direct equivalents.
pub ( crate ) struct JsonFunctionRewriter ;
11
12
@@ -15,25 +16,37 @@ impl FunctionRewrite for JsonFunctionRewriter {
15
16
}
16
17
17
18
fn rewrite ( & self , expr : Expr , _schema : & DFSchema , _config : & ConfigOptions ) -> Result < Transformed < Expr > > {
18
- if let Expr :: Cast ( cast) = & expr {
19
- if let Expr :: ScalarFunction ( func) = & * cast. expr {
20
- if func. func . name ( ) == "json_get" {
21
- if let Some ( t) = switch_json_get ( & cast. data_type , & func. args ) {
22
- return Ok ( t) ;
23
- }
24
- }
25
- }
26
- } else if let Expr :: ScalarFunction ( func) = & expr {
27
- if let Some ( new_func) = unnest_json_calls ( func) {
28
- return Ok ( Transformed :: yes ( Expr :: ScalarFunction ( new_func) ) ) ;
29
- }
30
- }
31
- Ok ( Transformed :: no ( expr) )
19
+ let transform = match & expr {
20
+ Expr :: Cast ( cast) => optimise_json_get_cast ( cast) ,
21
+ Expr :: ScalarFunction ( func) => unnest_json_calls ( func) ,
22
+ _ => None ,
23
+ } ;
24
+ Ok ( transform. unwrap_or_else ( || Transformed :: no ( expr) ) )
32
25
}
33
26
}
34
27
28
+ /// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of
29
+ /// extracting the right value type from JSON without the need to materialize the JSON union.
30
+ fn optimise_json_get_cast ( cast : & Cast ) -> Option < Transformed < Expr > > {
31
+ let scalar_func = extract_scalar_function ( & cast. expr ) ?;
32
+ if scalar_func. func . name ( ) != "json_get" {
33
+ return None ;
34
+ }
35
+ let func = match & cast. data_type {
36
+ DataType :: Boolean => crate :: json_get_bool:: json_get_bool_udf ( ) ,
37
+ DataType :: Float64 | DataType :: Float32 => crate :: json_get_float:: json_get_float_udf ( ) ,
38
+ DataType :: Int64 | DataType :: Int32 => crate :: json_get_int:: json_get_int_udf ( ) ,
39
+ DataType :: Utf8 => crate :: json_get_str:: json_get_str_udf ( ) ,
40
+ _ => return None ,
41
+ } ;
42
+ Some ( Transformed :: yes ( Expr :: ScalarFunction ( ScalarFunction {
43
+ func,
44
+ args : scalar_func. args . clone ( ) ,
45
+ } ) ) )
46
+ }
47
+
35
48
// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
36
- fn unnest_json_calls ( func : & ScalarFunction ) -> Option < ScalarFunction > {
49
+ fn unnest_json_calls ( func : & ScalarFunction ) -> Option < Transformed < Expr > > {
37
50
if !matches ! (
38
51
func. func. name( ) ,
39
52
"json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str"
@@ -42,9 +55,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
42
55
}
43
56
let mut outer_args_iter = func. args . iter ( ) ;
44
57
let first_arg = outer_args_iter. next ( ) ?;
45
- let Expr :: ScalarFunction ( inner_func) = first_arg else {
46
- return None ;
47
- } ;
58
+ let inner_func = extract_scalar_function ( first_arg) ?;
48
59
if inner_func. func . name ( ) != "json_get" {
49
60
return None ;
50
61
}
@@ -53,26 +64,48 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
53
64
args. extend ( outer_args_iter. cloned ( ) ) ;
54
65
// See #23, unnest only when all lookup arguments are literals
55
66
if args. iter ( ) . skip ( 1 ) . all ( |arg| matches ! ( arg, Expr :: Literal ( _) ) ) {
56
- Some ( ScalarFunction {
67
+ Some ( Transformed :: yes ( Expr :: ScalarFunction ( ScalarFunction {
57
68
func : func. func . clone ( ) ,
58
69
args,
59
- } )
70
+ } ) ) )
60
71
} else {
61
72
None
62
73
}
63
74
}
64
75
65
- fn switch_json_get ( cast_data_type : & DataType , args : & [ Expr ] ) -> Option < Transformed < Expr > > {
66
- let func = match cast_data_type {
67
- DataType :: Boolean => crate :: json_get_bool:: json_get_bool_udf ( ) ,
68
- DataType :: Float64 | DataType :: Float32 => crate :: json_get_float:: json_get_float_udf ( ) ,
69
- DataType :: Int64 | DataType :: Int32 => crate :: json_get_int:: json_get_int_udf ( ) ,
70
- DataType :: Utf8 => crate :: json_get_str:: json_get_str_udf ( ) ,
71
- _ => return None ,
72
- } ;
73
- let f = ScalarFunction {
74
- func,
75
- args : args. to_vec ( ) ,
76
- } ;
77
- Some ( Transformed :: yes ( Expr :: ScalarFunction ( f) ) )
76
+ fn extract_scalar_function ( expr : & Expr ) -> Option < & ScalarFunction > {
77
+ match expr {
78
+ Expr :: ScalarFunction ( func) => Some ( func) ,
79
+ Expr :: Alias ( alias) => extract_scalar_function ( & * alias. expr ) ,
80
+ _ => None ,
81
+ }
82
+ }
83
+
84
+ /// Implement a custom SQL planner to replace postgres JSON operators with custom UDFs
85
+ #[ derive( Debug , Default ) ]
86
+ pub struct JsonExprPlanner ;
87
+
88
+ impl ExprPlanner for JsonExprPlanner {
89
+ fn plan_binary_op ( & self , expr : RawBinaryExpr , _schema : & DFSchema ) -> Result < PlannerResult < RawBinaryExpr > > {
90
+ let ( func, op_display) = match & expr. op {
91
+ BinaryOperator :: Arrow => ( crate :: json_get:: json_get_udf ( ) , "->" ) ,
92
+ BinaryOperator :: LongArrow => ( crate :: json_get_str:: json_get_str_udf ( ) , "->>" ) ,
93
+ BinaryOperator :: Question => ( crate :: json_contains:: json_contains_udf ( ) , "?" ) ,
94
+ _ => return Ok ( PlannerResult :: Original ( expr) ) ,
95
+ } ;
96
+ let alias_name = match & expr. left {
97
+ Expr :: Alias ( alias) => format ! ( "{} {} {}" , alias. name, op_display, expr. right) ,
98
+ left_expr => format ! ( "{} {} {}" , left_expr, op_display, expr. right) ,
99
+ } ;
100
+
101
+ // we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)`
102
+ Ok ( PlannerResult :: Planned ( Expr :: Alias ( Alias :: new (
103
+ Expr :: ScalarFunction ( ScalarFunction {
104
+ func,
105
+ args : vec ! [ expr. left, expr. right] ,
106
+ } ) ,
107
+ None :: < & str > ,
108
+ alias_name,
109
+ ) ) ) )
110
+ }
78
111
}
0 commit comments