@@ -5,7 +5,8 @@ use datafusion::arrow::array::{
55 Array , ArrayRef , AsArray , DictionaryArray , Int64Array , LargeStringArray , PrimitiveArray , StringArray ,
66 StringViewArray , UInt64Array , UnionArray ,
77} ;
8- use datafusion:: arrow:: datatypes:: { ArrowNativeType , ArrowPrimitiveType , DataType } ;
8+ use datafusion:: arrow:: compute:: take;
9+ use datafusion:: arrow:: datatypes:: { ArrowDictionaryKeyType , ArrowNativeType , ArrowPrimitiveType , DataType } ;
910use datafusion:: arrow:: downcast_dictionary_array;
1011use datafusion:: common:: { exec_err, plan_err, Result as DataFusionResult , ScalarValue } ;
1112use datafusion:: logical_expr:: ColumnarValue ;
@@ -21,7 +22,7 @@ use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_arr
2122/// * `fn_name` - The name of the function
2223/// * `value_type` - The general return type of the function, might be wrapped in a dictionary depending
2324/// on the first argument
24- pub fn scalar_udf_return_type ( args : & [ DataType ] , fn_name : & str , value_type : DataType ) -> DataFusionResult < DataType > {
25+ pub fn return_type_check ( args : & [ DataType ] , fn_name : & str , value_type : DataType ) -> DataFusionResult < DataType > {
2526 let Some ( first) = args. first ( ) else {
2627 return plan_err ! ( "The '{fn_name}' function requires one or more arguments." ) ;
2728 } ;
@@ -105,6 +106,7 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
105106 jiter_find : impl Fn ( Option < & str > , & [ JsonPath ] ) -> Result < I , GetError > ,
106107 to_array : impl Fn ( C ) -> DataFusionResult < ArrayRef > ,
107108 to_scalar : impl Fn ( Option < I > ) -> ScalarValue ,
109+ return_dict : bool ,
108110) -> DataFusionResult < ColumnarValue > {
109111 let Some ( first_arg) = args. first ( ) else {
110112 // I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe
@@ -118,13 +120,17 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
118120 // TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
119121 exec_err ! ( "More than 1 path element is not supported when querying JSON using an array." )
120122 } else {
121- invoke_array ( json_array, a, to_array, jiter_find)
123+ invoke_array ( json_array, a, to_array, jiter_find, return_dict )
122124 }
123125 }
124- Some ( ColumnarValue :: Scalar ( _) ) => {
125- scalar_apply ( json_array, & JsonPath :: extract_path ( args) , to_array, jiter_find)
126- }
127- None => scalar_apply ( json_array, & [ ] , to_array, jiter_find) ,
126+ Some ( ColumnarValue :: Scalar ( _) ) => scalar_apply (
127+ json_array,
128+ & JsonPath :: extract_path ( args) ,
129+ to_array,
130+ jiter_find,
131+ return_dict,
132+ ) ,
133+ None => scalar_apply ( json_array, & [ ] , to_array, jiter_find, return_dict) ,
128134 } ;
129135 array. map ( ColumnarValue :: from)
130136 }
@@ -137,24 +143,26 @@ fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
137143 needle_array : & ArrayRef ,
138144 to_array : impl Fn ( C ) -> DataFusionResult < ArrayRef > ,
139145 jiter_find : impl Fn ( Option < & str > , & [ JsonPath ] ) -> Result < I , GetError > ,
146+ return_dict : bool ,
140147) -> DataFusionResult < ArrayRef > {
141148 if let Some ( d) = needle_array. as_any_dictionary_opt ( ) {
142- invoke_array ( json_array, d. values ( ) , to_array, jiter_find)
149+ // this is the (very rare) case where the needle is a dictionary, it shouldn't affect what we return
150+ invoke_array ( json_array, d. values ( ) , to_array, jiter_find, return_dict)
143151 } else if let Some ( str_path_array) = needle_array. as_any ( ) . downcast_ref :: < StringArray > ( ) {
144152 let paths = str_path_array. iter ( ) . map ( |opt_key| opt_key. map ( JsonPath :: Key ) ) ;
145- zip_apply ( json_array, paths, to_array, jiter_find, true )
153+ zip_apply ( json_array, paths, to_array, jiter_find, true , return_dict )
146154 } else if let Some ( str_path_array) = needle_array. as_any ( ) . downcast_ref :: < LargeStringArray > ( ) {
147155 let paths = str_path_array. iter ( ) . map ( |opt_key| opt_key. map ( JsonPath :: Key ) ) ;
148- zip_apply ( json_array, paths, to_array, jiter_find, true )
156+ zip_apply ( json_array, paths, to_array, jiter_find, true , return_dict )
149157 } else if let Some ( str_path_array) = needle_array. as_any ( ) . downcast_ref :: < StringViewArray > ( ) {
150158 let paths = str_path_array. iter ( ) . map ( |opt_key| opt_key. map ( JsonPath :: Key ) ) ;
151- zip_apply ( json_array, paths, to_array, jiter_find, true )
159+ zip_apply ( json_array, paths, to_array, jiter_find, true , return_dict )
152160 } else if let Some ( int_path_array) = needle_array. as_any ( ) . downcast_ref :: < Int64Array > ( ) {
153161 let paths = int_path_array. iter ( ) . map ( |opt_index| opt_index. map ( Into :: into) ) ;
154- zip_apply ( json_array, paths, to_array, jiter_find, false )
162+ zip_apply ( json_array, paths, to_array, jiter_find, false , return_dict )
155163 } else if let Some ( int_path_array) = needle_array. as_any ( ) . downcast_ref :: < UInt64Array > ( ) {
156164 let paths = int_path_array. iter ( ) . map ( |opt_index| opt_index. map ( Into :: into) ) ;
157- zip_apply ( json_array, paths, to_array, jiter_find, false )
165+ zip_apply ( json_array, paths, to_array, jiter_find, false , return_dict )
158166 } else {
159167 exec_err ! ( "unexpected second argument type, expected string or int array" )
160168 }
@@ -166,22 +174,15 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
166174 to_array : impl Fn ( C ) -> DataFusionResult < ArrayRef > ,
167175 jiter_find : impl Fn ( Option < & str > , & [ JsonPath ] ) -> Result < I , GetError > ,
168176 object_lookup : bool ,
177+ return_dict : bool ,
169178) -> DataFusionResult < ArrayRef > {
170179 // arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
171180 use datafusion:: arrow:: datatypes as arrow_schema;
172181
173182 let c = downcast_dictionary_array ! (
174183 json_array => {
175- let values = zip_apply( json_array. values( ) , path_array, to_array, jiter_find, object_lookup) ?;
176- if !is_json_union( values. data_type( ) ) {
177- return Ok ( Arc :: new( json_array. with_values( values) ) ) ;
178- }
179- // JSON union: post-process the array to set keys to null where the union member is null
180- let type_ids = values. as_any( ) . downcast_ref:: <UnionArray >( ) . unwrap( ) . type_ids( ) ;
181- return Ok ( Arc :: new( DictionaryArray :: new(
182- mask_dictionary_keys( json_array. keys( ) , type_ids) ,
183- values,
184- ) ) ) ;
184+ let values = zip_apply( json_array. values( ) , path_array, to_array, jiter_find, object_lookup, false ) ?;
185+ return post_process_dict( json_array, values, return_dict) ;
185186 }
186187 DataType :: Utf8 => zip_apply_iter( json_array. as_string:: <i32 >( ) . iter( ) , path_array, jiter_find) ,
187188 DataType :: LargeUtf8 => zip_apply_iter( json_array. as_string:: <i64 >( ) . iter( ) , path_array, jiter_find) ,
@@ -242,22 +243,15 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
242243 path : & [ JsonPath ] ,
243244 to_array : impl Fn ( C ) -> DataFusionResult < ArrayRef > ,
244245 jiter_find : impl Fn ( Option < & str > , & [ JsonPath ] ) -> Result < I , GetError > ,
246+ return_dict : bool ,
245247) -> DataFusionResult < ArrayRef > {
246248 // arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
247249 use datafusion:: arrow:: datatypes as arrow_schema;
248250
249251 let c = downcast_dictionary_array ! (
250252 json_array => {
251- let values = scalar_apply( json_array. values( ) , path, to_array, jiter_find) ?;
252- if !is_json_union( values. data_type( ) ) {
253- return Ok ( Arc :: new( json_array. with_values( values) ) ) ;
254- }
255- // JSON union: post-process the array to set keys to null where the union member is null
256- let type_ids = values. as_any( ) . downcast_ref:: <UnionArray >( ) . unwrap( ) . type_ids( ) ;
257- return Ok ( Arc :: new( DictionaryArray :: new(
258- mask_dictionary_keys( json_array. keys( ) , type_ids) ,
259- values,
260- ) ) ) ;
253+ let values = scalar_apply( json_array. values( ) , path, to_array, jiter_find, false ) ?;
254+ return post_process_dict( json_array, values, return_dict) ;
261255 }
262256 DataType :: Utf8 => scalar_apply_iter( json_array. as_string:: <i32 >( ) . iter( ) , path, jiter_find) ,
263257 DataType :: LargeUtf8 => scalar_apply_iter( json_array. as_string:: <i64 >( ) . iter( ) , path, jiter_find) ,
@@ -268,10 +262,32 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
268262 return exec_err!( "unexpected json array type {:?}" , other) ;
269263 }
270264 ) ;
271-
272265 to_array ( c)
273266}
274267
268+ /// Take a dictionary array of JSON data and an array of result values and combine them.
269+ fn post_process_dict < T : ArrowDictionaryKeyType > (
270+ dict_array : & DictionaryArray < T > ,
271+ result_values : ArrayRef ,
272+ return_dict : bool ,
273+ ) -> DataFusionResult < ArrayRef > {
274+ if return_dict {
275+ if is_json_union ( result_values. data_type ( ) ) {
276+ // JSON union: post-process the array to set keys to null where the union member is null
277+ let type_ids = result_values. as_any ( ) . downcast_ref :: < UnionArray > ( ) . unwrap ( ) . type_ids ( ) ;
278+ Ok ( Arc :: new ( DictionaryArray :: new (
279+ mask_dictionary_keys ( dict_array. keys ( ) , type_ids) ,
280+ result_values,
281+ ) ) )
282+ } else {
283+ Ok ( Arc :: new ( dict_array. with_values ( result_values) ) )
284+ }
285+ } else {
286+ // this is what cast would do under the hood to unpack a dictionary into an array of its values
287+ Ok ( take ( & result_values, dict_array. keys ( ) , None ) ?)
288+ }
289+ }
290+
275291fn is_object_lookup ( path : & [ JsonPath ] ) -> bool {
276292 if let Some ( first) = path. first ( ) {
277293 matches ! ( first, JsonPath :: Key ( _) )
0 commit comments