Skip to content

Commit 531e9ca

Browse files
committed
feat: implement json_get_array UDF
1 parent 4ad1e9b commit 531e9ca

File tree

3 files changed

+160
-0
lines changed

3 files changed

+160
-0
lines changed

src/json_get_array.rs

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
use std::any::Any;
2+
use std::sync::Arc;
3+
4+
use arrow::array::{GenericListArray, ListBuilder, StringBuilder};
5+
use arrow_schema::{DataType, Field};
6+
use datafusion_common::arrow::array::ArrayRef;
7+
use datafusion_common::{Result as DataFusionResult, ScalarValue};
8+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
9+
use jiter::Peek;
10+
11+
use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
12+
use crate::common_macros::make_udf_function;
13+
14+
struct StrArrayColumn {
15+
rows: GenericListArray<i32>,
16+
}
17+
18+
impl FromIterator<Option<Vec<String>>> for StrArrayColumn {
19+
fn from_iter<T: IntoIterator<Item = Option<Vec<String>>>>(iter: T) -> Self {
20+
let string_builder = StringBuilder::new();
21+
let mut list_builder = ListBuilder::new(string_builder);
22+
23+
for row in iter {
24+
if let Some(row) = row {
25+
for elem in row {
26+
list_builder.values().append_value(elem);
27+
}
28+
29+
list_builder.append(true);
30+
} else {
31+
list_builder.append(false);
32+
}
33+
}
34+
35+
Self {
36+
rows: list_builder.finish(),
37+
}
38+
}
39+
}
40+
41+
// NOTE(review): presumably expands to the `json_get_array` expression helper and
// the `json_get_array_udf` accessor that `lib.rs` re-exports from this module —
// confirm against `common_macros::make_udf_function`.
make_udf_function!(
    JsonGetArray,
    json_get_array,
    json_data path,
    r#"Get an arrow array value from a JSON string by its "path""#
);
47+
48+
#[derive(Debug)]
49+
pub(super) struct JsonGetArray {
50+
signature: Signature,
51+
aliases: [String; 1],
52+
}
53+
54+
impl Default for JsonGetArray {
55+
fn default() -> Self {
56+
Self {
57+
signature: Signature::variadic_any(Volatility::Immutable),
58+
aliases: ["json_get_array".to_string()],
59+
}
60+
}
61+
}
62+
63+
impl ScalarUDFImpl for JsonGetArray {
64+
fn as_any(&self) -> &dyn Any {
65+
self
66+
}
67+
68+
fn name(&self) -> &str {
69+
self.aliases[0].as_str()
70+
}
71+
72+
fn signature(&self) -> &Signature {
73+
&self.signature
74+
}
75+
76+
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
77+
check_args(arg_types, self.name()).map(|()| DataType::List(Field::new("item", DataType::Utf8, true).into()))
78+
}
79+
80+
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
81+
invoke::<StrArrayColumn, Vec<String>>(
82+
args,
83+
jiter_json_get_array,
84+
|c| Ok(Arc::new(c.rows) as ArrayRef),
85+
|i| {
86+
let string_builder = StringBuilder::new();
87+
let mut list_builder = ListBuilder::new(string_builder);
88+
89+
if let Some(row) = i {
90+
for elem in row {
91+
list_builder.values().append_value(elem);
92+
}
93+
}
94+
95+
ScalarValue::List(list_builder.finish().into())
96+
},
97+
)
98+
}
99+
100+
fn aliases(&self) -> &[String] {
101+
&self.aliases
102+
}
103+
}
104+
105+
fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result<Vec<String>, GetError> {
106+
if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
107+
match peek {
108+
Peek::Array => {
109+
let mut peek_opt = jiter.known_array()?;
110+
let mut array_values = Vec::new();
111+
112+
while let Some(peek) = peek_opt {
113+
let start = jiter.current_index();
114+
jiter.known_skip(peek)?;
115+
let object_slice = jiter.slice_to_current(start);
116+
let object_string = std::str::from_utf8(object_slice)?;
117+
118+
array_values.push(object_string.to_owned());
119+
120+
peek_opt = jiter.array_step()?;
121+
}
122+
123+
Ok(array_values)
124+
}
125+
_ => get_err!(),
126+
}
127+
} else {
128+
get_err!()
129+
}
130+
}

src/lib.rs

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod common_union;
1010
mod json_as_text;
1111
mod json_contains;
1212
mod json_get;
13+
mod json_get_array;
1314
mod json_get_bool;
1415
mod json_get_float;
1516
mod json_get_int;
@@ -22,6 +23,7 @@ pub mod functions {
2223
pub use crate::json_as_text::json_as_text;
2324
pub use crate::json_contains::json_contains;
2425
pub use crate::json_get::json_get;
26+
pub use crate::json_get_array::json_get_array;
2527
pub use crate::json_get_bool::json_get_bool;
2628
pub use crate::json_get_float::json_get_float;
2729
pub use crate::json_get_int::json_get_int;
@@ -34,6 +36,7 @@ pub mod udfs {
3436
pub use crate::json_as_text::json_as_text_udf;
3537
pub use crate::json_contains::json_contains_udf;
3638
pub use crate::json_get::json_get_udf;
39+
pub use crate::json_get_array::json_get_array_udf;
3740
pub use crate::json_get_bool::json_get_bool_udf;
3841
pub use crate::json_get_float::json_get_float_udf;
3942
pub use crate::json_get_int::json_get_int_udf;
@@ -54,6 +57,7 @@ pub mod udfs {
5457
pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
5558
let functions: Vec<Arc<ScalarUDF>> = vec![
5659
json_get::json_get_udf(),
60+
json_get_array::json_get_array_udf(),
5761
json_get_bool::json_get_bool_udf(),
5862
json_get_float::json_get_float_udf(),
5963
json_get_int::json_get_int_udf(),

tests/main.rs

+26
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,31 @@ async fn test_json_length_vec() {
476476
assert_batches_eq!(expected, &batches);
477477
}
478478

479+
#[tokio::test]
480+
async fn test_json_get_arrow_array() {
481+
let sql = r#"select name, json_get_array(json_data, 'foo') from test"#;
482+
let batches = run_query(sql).await.unwrap();
483+
484+
let expected = [
485+
"+------------------+--------------------------------------------+",
486+
"| name | json_get_array(test.json_data,Utf8(\"foo\")) |",
487+
"+------------------+--------------------------------------------+",
488+
"| object_foo | |",
489+
"| object_foo_array | [1] |",
490+
"| object_foo_obj | |",
491+
"| object_foo_null | |",
492+
"| object_bar | |",
493+
"| list_foo | |",
494+
"| invalid_json | |",
495+
"+------------------+--------------------------------------------+",
496+
];
497+
498+
assert_batches_eq!(expected, &batches);
499+
500+
let batches = run_query_large(sql).await.unwrap();
501+
assert_batches_eq!(expected, &batches);
502+
}
503+
479504
#[tokio::test]
480505
async fn test_no_args() {
481506
let err = run_query(r#"select json_len()"#).await.unwrap_err();
@@ -1131,6 +1156,7 @@ async fn test_long_arrow_cast() {
11311156
assert_batches_eq!(expected, &batches);
11321157
}
11331158

1159+
#[tokio::test]
11341160
async fn test_arrow_cast_numeric() {
11351161
let sql = r#"select ('{"foo": 420}'->'foo')::numeric = 420"#;
11361162
let batches = run_query(sql).await.unwrap();

0 commit comments

Comments
 (0)