@@ -3,18 +3,20 @@ use std::sync::Arc;
3
3
4
4
use datafusion:: arrow:: array:: {
5
5
Array , ArrayAccessor , ArrayRef , AsArray , DictionaryArray , Int64Array , LargeStringArray , PrimitiveArray ,
6
- StringArray , StringViewArray , UInt64Array , UnionArray ,
6
+ StringArray , StringViewArray , UInt64Array ,
7
7
} ;
8
8
use datafusion:: arrow:: compute:: take;
9
9
use datafusion:: arrow:: datatypes:: {
10
- ArrowDictionaryKeyType , ArrowNativeType , ArrowPrimitiveType , DataType , Int64Type , UInt64Type ,
10
+ ArrowDictionaryKeyType , ArrowNativeType , ArrowNativeTypeOp , DataType , Int64Type , UInt64Type ,
11
11
} ;
12
12
use datafusion:: arrow:: downcast_dictionary_array;
13
13
use datafusion:: common:: { exec_err, plan_err, Result as DataFusionResult , ScalarValue } ;
14
14
use datafusion:: logical_expr:: ColumnarValue ;
15
15
use jiter:: { Jiter , JiterError , Peek } ;
16
16
17
- use crate :: common_union:: { is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL } ;
17
+ use crate :: common_union:: {
18
+ is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL ,
19
+ } ;
18
20
19
21
/// General implementation of `ScalarUDFImpl::return_type`.
20
22
///
@@ -95,6 +97,7 @@ impl From<i64> for JsonPath<'_> {
95
97
}
96
98
}
97
99
100
+ #[ derive( Debug ) ]
98
101
enum JsonPathArgs < ' a > {
99
102
Array ( & ' a ArrayRef ) ,
100
103
Scalars ( Vec < JsonPath < ' a > > ) ,
@@ -175,9 +178,48 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
175
178
) -> DataFusionResult < ArrayRef > {
176
179
downcast_dictionary_array ! (
177
180
json_array => {
178
- let values = invoke_array_array( json_array. values( ) , path_array, to_array, jiter_find, return_dict) ?;
179
- post_process_dict( json_array, values, return_dict)
180
- }
181
+ fn wrap_as_dictionary<K : ArrowDictionaryKeyType >( original: & DictionaryArray <K >, new_values: ArrayRef ) -> DictionaryArray <K > {
182
+ assert_eq!( original. keys( ) . len( ) , new_values. len( ) ) ;
183
+ let mut key = K :: Native :: ZERO ;
184
+ let key_range = std:: iter:: from_fn( move || {
185
+ let next = key;
186
+ key = key. add_checked( K :: Native :: ONE ) . expect( "keys exhausted" ) ;
187
+ Some ( next)
188
+ } ) . take( new_values. len( ) ) ;
189
+ let mut keys = PrimitiveArray :: <K >:: from_iter_values( key_range) ;
190
+ if is_json_union( new_values. data_type( ) ) {
191
+ // JSON union: post-process the array to set keys to null where the union member is null
192
+ let type_ids = new_values. as_union( ) . type_ids( ) ;
193
+ keys = mask_dictionary_keys( & keys, type_ids) ;
194
+ }
195
+ DictionaryArray :: <K >:: new( keys, new_values)
196
+ }
197
+
198
+ // TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
199
+ // combinations and do less work, but this can be left as a future optimization
200
+ let output = match json_array. values( ) . data_type( ) {
201
+ DataType :: Utf8 => zip_apply( json_array. downcast_dict:: <StringArray >( ) . unwrap( ) , path_array, to_array, jiter_find) ,
202
+ DataType :: LargeUtf8 => zip_apply( json_array. downcast_dict:: <LargeStringArray >( ) . unwrap( ) , path_array, to_array, jiter_find) ,
203
+ DataType :: Utf8View => zip_apply( json_array. downcast_dict:: <StringViewArray >( ) . unwrap( ) , path_array, to_array, jiter_find) ,
204
+ other => if let Some ( child_array) = nested_json_array_ref( json_array. values( ) , is_object_lookup_array( path_array. data_type( ) ) ) {
205
+ // Horrible case: dict containing union as input with array for paths, figure
206
+ // out from the path type which union members we should access, repack the
207
+ // dictionary and then recurse.
208
+ //
209
+ // Use direct return because if return_dict applies, the recursion will handle it.
210
+ return invoke_array_array( & ( Arc :: new( json_array. with_values( child_array. clone( ) ) ) as _) , path_array, to_array, jiter_find, return_dict)
211
+ } else {
212
+ exec_err!( "unexpected json array type {:?}" , other)
213
+ }
214
+ } ?;
215
+
216
+ if return_dict {
217
+ // ensure return is a dictionary to satisfy the declaration above in return_type_check
218
+ Ok ( Arc :: new( wrap_as_dictionary( json_array, output) ) )
219
+ } else {
220
+ Ok ( output)
221
+ }
222
+ } ,
181
223
DataType :: Utf8 => zip_apply( json_array. as_string:: <i32 >( ) . iter( ) , path_array, to_array, jiter_find) ,
182
224
DataType :: LargeUtf8 => zip_apply( json_array. as_string:: <i64 >( ) . iter( ) , path_array, to_array, jiter_find) ,
183
225
DataType :: Utf8View => zip_apply( json_array. as_string_view( ) . iter( ) , path_array, to_array, jiter_find) ,
@@ -239,6 +281,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
239
281
to_array,
240
282
jiter_find,
241
283
)
284
+ // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
242
285
. map ( ColumnarValue :: Array )
243
286
}
244
287
@@ -250,6 +293,7 @@ fn invoke_scalar_scalars<I>(
250
293
) -> DataFusionResult < ColumnarValue > {
251
294
let s = extract_json_scalar ( scalar) ?;
252
295
let v = jiter_find ( s, path) . ok ( ) ;
296
+ // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
253
297
Ok ( ColumnarValue :: Scalar ( to_scalar ( v) ) )
254
298
}
255
299
@@ -321,7 +365,7 @@ fn post_process_dict<T: ArrowDictionaryKeyType>(
321
365
if return_dict {
322
366
if is_json_union ( result_values. data_type ( ) ) {
323
367
// JSON union: post-process the array to set keys to null where the union member is null
324
- let type_ids = result_values. as_any ( ) . downcast_ref :: < UnionArray > ( ) . unwrap ( ) . type_ids ( ) ;
368
+ let type_ids = result_values. as_union ( ) . type_ids ( ) ;
325
369
Ok ( Arc :: new ( DictionaryArray :: new (
326
370
mask_dictionary_keys ( dict_array. keys ( ) , type_ids) ,
327
371
result_values,
@@ -413,7 +457,7 @@ impl From<Utf8Error> for GetError {
413
457
///
414
458
/// That said, doing this might also be an optimization for cases like null-checking without needing
415
459
/// to check the value union array.
416
- fn mask_dictionary_keys < K : ArrowPrimitiveType > ( keys : & PrimitiveArray < K > , type_ids : & [ i8 ] ) -> PrimitiveArray < K > {
460
+ fn mask_dictionary_keys < K : ArrowDictionaryKeyType > ( keys : & PrimitiveArray < K > , type_ids : & [ i8 ] ) -> PrimitiveArray < K > {
417
461
let mut null_mask = vec ! [ true ; keys. len( ) ] ;
418
462
for ( i, k) in keys. iter ( ) . enumerate ( ) {
419
463
match k {
0 commit comments