Commit 40a8090

unnest json_get calls (#25)
1 parent 783b60b commit 40a8090

5 files changed: +196 -14 lines

Cargo.toml (+2 -2)

@@ -8,7 +8,7 @@ license = "Apache-2.0"
 keywords = ["datafusion", "JSON", "SQL"]
 categories = ["database-implementations", "parsing"]
 repository = "https://github.com/datafusion-contrib/datafusion-functions-json/"
-rust-version = "1.73.0"
+rust-version = "1.76.0"

 [dependencies]
 arrow = "52"
@@ -24,7 +24,7 @@ datafusion-execution = "39"
 codspeed-criterion-compat = "2.3"
 criterion = "0.5.1"
 datafusion = "39"
-clap = "~4.4" # for testing on MSRV 1.73
+clap = "4"
 tokio = { version = "1.37", features = ["full"] }

 [lints.clippy]

src/common.rs (+12 -6)

@@ -75,6 +75,12 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
         ColumnarValue::Array(json_array) => {
             let result_collect = match args.get(1) {
                 Some(ColumnarValue::Array(a)) => {
+                    if args.len() > 2 {
+                        // TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
+                        return exec_err!(
+                            "More than 1 path element is not supported when querying JSON using an array."
+                        );
+                    }
                     if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
                         let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
                         zip_apply(json_array, paths, jiter_find, true)
@@ -114,28 +120,28 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
 
 fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
     json_array: &ArrayRef,
-    paths: P,
+    path_array: P,
     jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
     object_lookup: bool,
 ) -> DataFusionResult<C> {
     if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
-        Ok(zip_apply_iter(string_array.iter(), paths, jiter_find))
+        Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find))
     } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
-        Ok(zip_apply_iter(large_string_array.iter(), paths, jiter_find))
+        Ok(zip_apply_iter(large_string_array.iter(), path_array, jiter_find))
     } else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
-        Ok(zip_apply_iter(string_array.iter(), paths, jiter_find))
+        Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find))
     } else {
         exec_err!("unexpected json array type {:?}", json_array.data_type())
     }
 }
 
 fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
     json_iter: impl Iterator<Item = Option<&'j str>>,
-    paths: P,
+    path_array: P,
     jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
 ) -> C {
     json_iter
-        .zip(paths)
+        .zip(path_array)
         .map(|(opt_json, opt_path)| {
             if let Some(path) = opt_path {
                 jiter_find(opt_json, &[path]).ok()
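
Note on the guard added above: when the lookup key is itself a column (the `ColumnarValue::Array` branch), `invoke` pairs each JSON value with the key from the same row via `zip_apply` / `zip_apply_iter`; supporting a second path element there would mean zipping several key arrays at once, which is rejected for now and left for #23. A standalone sketch of that row-wise pairing, using plain slices and a deliberately naive `lookup` helper as stand-ins for the crate's Arrow arrays and jiter-based parsing (all names here are illustrative, not the crate's API):

// Simplified, standalone illustration of the row-wise lookup that
// `zip_apply` / `zip_apply_iter` perform when the key is a column.

/// Hypothetical single-key lookup: find `key` in a flat JSON object string.
fn lookup<'a>(json: &'a str, key: &str) -> Option<&'a str> {
    // Extremely naive: locate `"key":` and take everything up to the next ',' or '}'.
    let needle = format!("\"{key}\":");
    let start = json.find(&needle)? + needle.len();
    let rest = &json[start..];
    let end = rest.find(|c| c == ',' || c == '}').unwrap_or(rest.len());
    Some(rest[..end].trim())
}

/// Row-wise application: the i-th JSON value is paired with the i-th key,
/// mirroring `json_iter.zip(path_array)` in `zip_apply_iter`.
fn zip_lookup<'a>(json_col: &[Option<&'a str>], key_col: &[Option<&str>]) -> Vec<Option<&'a str>> {
    json_col
        .iter()
        .zip(key_col.iter())
        .map(|(json, key)| match (*json, *key) {
            (Some(j), Some(k)) => lookup(j, k),
            // A null JSON value or a null key yields a null result, as in the real kernel.
            _ => None,
        })
        .collect()
}

fn main() {
    let json_col = [Some(r#"{"a": 1, "b": 2}"#), Some(r#"{"b": 3}"#), None];
    let key_col = [Some("a"), Some("b"), Some("a")];
    // One key column only: a second path element would require zipping another
    // array per row, which is exactly the case the new guard rejects.
    assert_eq!(zip_lookup(&json_col, &key_col), vec![Some("1"), Some("3"), None]);
    println!("row-wise lookup ok");
}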

src/rewrite.rs (+34)

@@ -23,11 +23,45 @@ impl FunctionRewrite for JsonFunctionRewriter {
                     }
                 }
             }
+        } else if let Expr::ScalarFunction(func) = &expr {
+            if let Some(new_func) = unnest_json_calls(func) {
+                return Ok(Transformed::yes(Expr::ScalarFunction(new_func)));
+            }
         }
         Ok(Transformed::no(expr))
     }
 }
 
+// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
+fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
+    if !matches!(
+        func.func.name(),
+        "json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str"
+    ) {
+        return None;
+    }
+    let mut outer_args_iter = func.args.iter();
+    let first_arg = outer_args_iter.next()?;
+    let Expr::ScalarFunction(inner_func) = first_arg else {
+        return None;
+    };
+    if inner_func.func.name() != "json_get" {
+        return None;
+    }
+
+    let mut args = inner_func.args.clone();
+    args.extend(outer_args_iter.cloned());
+    // See #23, unnest only when all lookup arguments are literals
+    if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) {
+        Some(ScalarFunction {
+            func: func.func.clone(),
+            args,
+        })
+    } else {
+        None
+    }
+}
+
 fn switch_json_get(cast_data_type: &DataType, args: &[Expr]) -> Option<Transformed<Expr>> {
     let func = match cast_data_type {
         DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
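
In short, `unnest_json_calls` collapses `json_get(json_get(col, 'foo'), 'bar')` into `json_get(col, 'foo', 'bar')` (the outer call may be any of the typed `json_get_*` variants, the inner one must be plain `json_get`), and it only fires when every lookup argument after the JSON operand is a literal; column-valued keys are left nested, per #23. A simplified sketch of the same transformation over a toy expression type rather than DataFusion's `Expr` / `ScalarFunction` (all types and names below are illustrative only):

// Toy model of the rewrite performed by `unnest_json_calls`, using a minimal
// expression enum instead of DataFusion's `Expr`.
#[derive(Clone, Debug)]
enum Expr {
    Column(String),
    Literal(String),
    Call { name: String, args: Vec<Expr> },
}

/// If `expr` is `outer(json_get(x, a...), b...)` where `outer` is one of the
/// json_get* functions and every lookup argument is a literal, flatten it to
/// `outer(x, a..., b...)`; otherwise return None and leave the expression alone.
fn unnest(expr: &Expr) -> Option<Expr> {
    let Expr::Call { name, args } = expr else { return None };
    if !matches!(
        name.as_str(),
        "json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str"
    ) {
        return None;
    }
    let (first, outer_rest) = args.split_first()?;
    let Expr::Call { name: inner_name, args: inner_args } = first else { return None };
    if inner_name != "json_get" {
        return None;
    }
    let mut new_args = inner_args.clone();
    new_args.extend(outer_rest.iter().cloned());
    // Mirror the literal-only check from the real rewrite (see #23).
    if new_args.iter().skip(1).all(|a| matches!(a, Expr::Literal(_))) {
        Some(Expr::Call { name: name.clone(), args: new_args })
    } else {
        None
    }
}

fn main() {
    // json_get_int(json_get(json_data, 'foo'), 0)  ->  json_get_int(json_data, 'foo', 0)
    let nested = Expr::Call {
        name: "json_get_int".into(),
        args: vec![
            Expr::Call {
                name: "json_get".into(),
                args: vec![Expr::Column("json_data".into()), Expr::Literal("foo".into())],
            },
            Expr::Literal("0".into()),
        ],
    };
    let flat = unnest(&nested).expect("literal keys should be flattened");
    println!("{flat:?}");

    // json_get(json_get(json_data, str_key1), str_key2) stays nested: the keys are columns.
    let col_keys = Expr::Call {
        name: "json_get".into(),
        args: vec![
            Expr::Call {
                name: "json_get".into(),
                args: vec![Expr::Column("json_data".into()), Expr::Column("str_key1".into())],
            },
            Expr::Column("str_key2".into()),
        ],
    };
    assert!(unnest(&col_keys).is_none());
}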

tests/main.rs (+141 -6)

@@ -5,7 +5,7 @@ use datafusion_common::ScalarValue;
 mod utils;
 use datafusion_expr::ColumnarValue;
 use datafusion_functions_json::udfs::json_get_str_udf;
-use utils::{display_val, run_query, run_query_large, run_query_params};
+use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params};
 
 #[tokio::test]
 async fn test_json_contains() {
@@ -523,7 +523,7 @@ async fn test_json_get_union_scalar() {
 }
 
 #[tokio::test]
-async fn test_json_get_union_array() {
+async fn test_json_get_nested_collapsed() {
     let expected = [
         "+------------------+---------+",
         "| name             | v       |",
@@ -545,7 +545,130 @@ async fn test_json_get_union_array() {
 }
 
 #[tokio::test]
-async fn test_json_get_union_array_skip() {
+async fn test_json_get_cte() {
+    // avoid auto-un-nesting with a CTE
+    let sql = r#"
+        with t as (select name, json_get(json_data, 'foo') j from test)
+        select name, json_get(j, 0) v from t
+    "#;
+    let expected = [
+        "+------------------+---------+",
+        "| name             | v       |",
+        "+------------------+---------+",
+        "| object_foo       | {null=} |",
+        "| object_foo_array | {int=1} |",
+        "| object_foo_obj   | {null=} |",
+        "| object_foo_null  | {null=} |",
+        "| object_bar       | {null=} |",
+        "| list_foo         | {null=} |",
+        "| invalid_json     | {null=} |",
+        "+------------------+---------+",
+    ];
+
+    let batches = run_query(sql).await.unwrap();
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_cte_plan() {
+    // avoid auto-unnesting with a CTE
+    let sql = r#"
+        explain
+        with t as (select name, json_get(json_data, 'foo') j from test)
+        select name, json_get(j, 0) v from t
+    "#;
+    let expected = [
+        "Projection: t.name, json_get(t.j, Int64(0)) AS v",
+        "  SubqueryAlias: t",
+        "    Projection: test.name, json_get(test.json_data, Utf8(\"foo\")) AS j",
+        "      TableScan: test projection=[name, json_data]",
+    ];
+
+    let plan_lines = logical_plan(sql).await;
+    assert_eq!(plan_lines, expected);
+}
+
+#[tokio::test]
+async fn test_json_get_unnest() {
+    let sql = "select name, json_get(json_get(json_data, 'foo'), 0) v from test";
+
+    let expected = [
+        "+------------------+---------+",
+        "| name             | v       |",
+        "+------------------+---------+",
+        "| object_foo       | {null=} |",
+        "| object_foo_array | {int=1} |",
+        "| object_foo_obj   | {null=} |",
+        "| object_foo_null  | {null=} |",
+        "| object_bar       | {null=} |",
+        "| list_foo         | {null=} |",
+        "| invalid_json     | {null=} |",
+        "+------------------+---------+",
+    ];
+
+    let batches = run_query(sql).await.unwrap();
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_unnest_plan() {
+    let sql = "explain select json_get(json_get(json_data, 'foo'), 0) v from test";
+    let expected = [
+        "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS v",
+        "  TableScan: test projection=[json_data]",
+    ];
+
+    let plan_lines = logical_plan(sql).await;
+    assert_eq!(plan_lines, expected);
+}
+
+#[tokio::test]
+async fn test_json_get_int_unnest() {
+    let sql = "select name, json_get(json_get(json_data, 'foo'), 0)::int v from test";
+
+    let expected = [
+        "+------------------+---+",
+        "| name             | v |",
+        "+------------------+---+",
+        "| object_foo       |   |",
+        "| object_foo_array | 1 |",
+        "| object_foo_obj   |   |",
+        "| object_foo_null  |   |",
+        "| object_bar       |   |",
+        "| list_foo         |   |",
+        "| invalid_json     |   |",
+        "+------------------+---+",
+    ];
+
+    let batches = run_query(sql).await.unwrap();
+    assert_batches_eq!(expected, &batches);
+}
+
+#[tokio::test]
+async fn test_json_get_int_unnest_plan() {
+    let sql = "explain select json_get(json_get(json_data, 'foo'), 0)::int v from test";
+    let expected = [
+        "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS v",
+        "  TableScan: test projection=[json_data]",
+    ];
+
+    let plan_lines = logical_plan(sql).await;
+    assert_eq!(plan_lines, expected);
+}
+
+#[tokio::test]
+async fn test_multiple_lookup_arrays() {
+    let sql = "select json_get(json_data, str_key1, str_key2) v from more_nested";
+    let err = run_query(sql).await.unwrap_err();
+    assert_eq!(
+        err.to_string(),
+        "Execution error: More than 1 path element is not supported when querying JSON using an array."
+    );
+}
+
+#[tokio::test]
+async fn test_json_get_union_array_nested() {
+    let sql = "select json_get(json_get(json_data, str_key1), str_key2) v from more_nested";
     let expected = [
         "+-------------+",
         "| v           |",
@@ -556,13 +679,27 @@ async fn test_json_get_union_array_skip() {
         "+-------------+",
     ];
 
-    let sql = "select json_get(json_get(json_data, str_key1), str_key2) v from more_nested";
     let batches = run_query(sql).await.unwrap();
     assert_batches_eq!(expected, &batches);
 }
 
+#[tokio::test]
+async fn test_json_get_union_array_nested_plan() {
+    let sql = "explain select json_get(json_get(json_data, str_key1), str_key2) v from more_nested";
+    // json_get is not un-nested because lookup types are not literals
+    let expected = [
+        "Projection: json_get(json_get(more_nested.json_data, more_nested.str_key1), more_nested.str_key2) AS v",
+        "  TableScan: more_nested projection=[json_data, str_key1, str_key2]",
+    ];
+
+    let plan_lines = logical_plan(sql).await;
+    assert_eq!(plan_lines, expected);
+}
+
 #[tokio::test]
 async fn test_json_get_union_array_skip_double_nested() {
+    let sql =
+        "select json_data, json_get_int(json_get(json_get(json_data, str_key1), str_key2), int_key) v from more_nested";
     let expected = [
         "+--------------------------+---+",
         "| json_data                | v |",
@@ -573,8 +710,6 @@ async fn test_json_get_union_array_skip_double_nested() {
         "+--------------------------+---+",
     ];
 
-    let sql =
-        "select json_data, json_get_int(json_get(json_get(json_data, str_key1), str_key2), int_key) v from more_nested";
     let batches = run_query(sql).await.unwrap();
     assert_batches_eq!(expected, &batches);
 }

tests/utils/mod.rs (+7)

@@ -142,3 +142,10 @@ pub async fn display_val(batch: Vec<RecordBatch>) -> (DataType, String) {
     let repr = f.value(0).try_to_string().unwrap();
     (schema_col.data_type().clone(), repr)
 }
+
+pub async fn logical_plan(sql: &str) -> Vec<String> {
+    let batches = run_query(sql).await.unwrap();
+    let plan_col = batches[0].column(1).as_any().downcast_ref::<StringArray>().unwrap();
+    let logical_plan = plan_col.value(0);
+    logical_plan.split('\n').map(|s| s.to_string()).collect()
+}
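
The new `logical_plan` helper runs the query (expected to be an `explain` statement), takes the plan text from column 1 of the first batch of the EXPLAIN output, and returns it split into lines, which the new `*_plan` tests compare line by line. A minimal usage sketch in the same style as the tests added in tests/main.rs (the plan fragment echoes `test_json_get_unnest_plan`; the test name is hypothetical and assumes the same `test` table and harness):

#[tokio::test]
async fn plan_is_unnested() {
    // After JsonFunctionRewriter runs, the two nested calls are collapsed into one json_get.
    let plan = logical_plan("explain select json_get(json_get(json_data, 'foo'), 0) v from test").await;
    assert!(plan[0].contains(r#"json_get(test.json_data, Utf8("foo"), Int64(0))"#));
}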
