11use std:: sync:: Arc ;
22
3- use arrow:: array:: { Array , cast:: AsArray } ;
3+ #[ cfg( test) ]
4+ use arrow:: array:: StructArray ;
5+ use arrow:: array:: { Array , ArrayRef , cast:: AsArray } ;
6+ #[ cfg( test) ]
7+ use arrow_schema:: Fields ;
48use arrow_schema:: extension:: ExtensionType ;
59use arrow_schema:: { DataType , Field } ;
610use datafusion:: common:: exec_datafusion_err;
711use datafusion:: error:: Result ;
12+ use datafusion:: logical_expr:: { ColumnarValue , ScalarFunctionArgs } ;
813use datafusion:: { common:: exec_err, scalar:: ScalarValue } ;
14+ use parquet_variant:: { Variant , VariantPath } ;
915use parquet_variant_compute:: { VariantArray , VariantType } ;
1016
1117#[ cfg( test) ]
@@ -118,6 +124,129 @@ pub fn try_parse_string_columnar(array: &Arc<dyn Array>) -> Result<Vec<Option<&s
118124 Err ( exec_datafusion_err ! ( "expected string array" ) )
119125}
120126
127+ pub fn variant_get_single_value < T > (
128+ variant_array : & VariantArray ,
129+ index : usize ,
130+ path : & str ,
131+ extract : for <' m , ' v > fn ( Variant < ' m , ' v > ) -> Result < Option < T > > ,
132+ ) -> Result < Option < T > > {
133+ let Some ( variant) = variant_array. iter ( ) . nth ( index) . flatten ( ) else {
134+ return Ok ( None ) ;
135+ } ;
136+
137+ let variant_path = VariantPath :: from ( path) ;
138+ let Some ( value) = variant. get_path ( & variant_path) else {
139+ return Ok ( None ) ;
140+ } ;
141+
142+ extract ( value)
143+ }
144+
145+ pub fn variant_get_array_values < T > (
146+ variant_array : & VariantArray ,
147+ path : & str ,
148+ extract : for <' m , ' v > fn ( Variant < ' m , ' v > ) -> Result < Option < T > > ,
149+ ) -> Result < Vec < Option < T > > > {
150+ let variant_path = VariantPath :: from ( path) ;
151+
152+ variant_array
153+ . iter ( )
154+ . map ( |maybe_variant| {
155+ let Some ( variant) = maybe_variant else {
156+ return Ok ( None ) ;
157+ } ;
158+
159+ let Some ( value) = variant. get_path ( & variant_path) else {
160+ return Ok ( None ) ;
161+ } ;
162+
163+ extract ( value)
164+ } )
165+ . collect ( )
166+ }
167+
168+ pub fn invoke_variant_get_typed < T > (
169+ args : ScalarFunctionArgs ,
170+ scalar_from_option : fn ( Option < T > ) -> ScalarValue ,
171+ array_from_values : fn ( Vec < Option < T > > ) -> ArrayRef ,
172+ extract : for <' m , ' v > fn ( Variant < ' m , ' v > ) -> Result < Option < T > > ,
173+ ) -> Result < ColumnarValue > {
174+ let ( variant_arg, path_arg) = match args. args . as_slice ( ) {
175+ [ variant_arg, path_arg] => ( variant_arg, path_arg) ,
176+ _ => return exec_err ! ( "expected 2 arguments" ) ,
177+ } ;
178+
179+ let variant_field = args
180+ . arg_fields
181+ . first ( )
182+ . ok_or_else ( || exec_datafusion_err ! ( "expected argument field" ) ) ?;
183+
184+ try_field_as_variant_array ( variant_field. as_ref ( ) ) ?;
185+
186+ let out = match ( variant_arg, path_arg) {
187+ ( ColumnarValue :: Array ( variant_array) , ColumnarValue :: Scalar ( path_scalar) ) => {
188+ let path = try_parse_string_scalar ( path_scalar) ?
189+ . map ( |s| s. as_str ( ) )
190+ . unwrap_or_default ( ) ;
191+
192+ let variant_array = VariantArray :: try_new ( variant_array. as_ref ( ) ) ?;
193+ let values = variant_get_array_values ( & variant_array, path, extract) ?;
194+ ColumnarValue :: Array ( array_from_values ( values) )
195+ }
196+ ( ColumnarValue :: Scalar ( scalar_variant) , ColumnarValue :: Scalar ( path_scalar) ) => {
197+ let ScalarValue :: Struct ( variant_array) = scalar_variant else {
198+ return exec_err ! ( "expected struct array" ) ;
199+ } ;
200+
201+ let path = try_parse_string_scalar ( path_scalar) ?
202+ . map ( |s| s. as_str ( ) )
203+ . unwrap_or_default ( ) ;
204+
205+ let variant_array = VariantArray :: try_new ( variant_array. as_ref ( ) ) ?;
206+ let value = variant_get_single_value ( & variant_array, 0 , path, extract) ?;
207+
208+ ColumnarValue :: Scalar ( scalar_from_option ( value) )
209+ }
210+ ( ColumnarValue :: Array ( variant_array) , ColumnarValue :: Array ( paths) ) => {
211+ if variant_array. len ( ) != paths. len ( ) {
212+ return exec_err ! ( "expected variant array and paths to be of same length" ) ;
213+ }
214+
215+ let paths = try_parse_string_columnar ( paths) ?;
216+ let variant_array = VariantArray :: try_new ( variant_array. as_ref ( ) ) ?;
217+
218+ let values: Vec < Option < T > > = ( 0 ..variant_array. len ( ) )
219+ . map ( |i| {
220+ let path = paths[ i] . unwrap_or_default ( ) ;
221+ variant_get_single_value ( & variant_array, i, path, extract)
222+ } )
223+ . collect :: < Result < _ > > ( ) ?;
224+
225+ ColumnarValue :: Array ( array_from_values ( values) )
226+ }
227+ ( ColumnarValue :: Scalar ( scalar_variant) , ColumnarValue :: Array ( paths) ) => {
228+ let ScalarValue :: Struct ( variant_array) = scalar_variant else {
229+ return exec_err ! ( "expected struct array" ) ;
230+ } ;
231+
232+ let variant_array = VariantArray :: try_new ( variant_array. as_ref ( ) ) ?;
233+ let paths = try_parse_string_columnar ( paths) ?;
234+
235+ let values: Vec < Option < T > > = paths
236+ . iter ( )
237+ . map ( |path| {
238+ let path = path. unwrap_or_default ( ) ;
239+ variant_get_single_value ( & variant_array, 0 , path, extract)
240+ } )
241+ . collect :: < Result < _ > > ( ) ?;
242+
243+ ColumnarValue :: Array ( array_from_values ( values) )
244+ }
245+ } ;
246+
247+ Ok ( out)
248+ }
249+
121250/// This is similar to anyhow's ensure! macro
122251/// If the `pred` fails, it will return a DataFusionError
123252pub fn ensure ( pred : bool , err_msg : & str ) -> Result < ( ) > {
@@ -139,6 +268,50 @@ pub fn build_variant_array_from_json(value: &serde_json::Value) -> VariantArray
139268 builder. build ( )
140269}
141270
/// Test helper: builds a one-row variant array from `json` and wraps it in a
/// `ScalarValue::Struct`.
///
/// # Panics
///
/// Panics if the builder rejects the serialized JSON text.
#[cfg(test)]
pub fn variant_scalar_from_json(json: serde_json::Value) -> ScalarValue {
    let json_text = json.to_string();

    let mut builder = VariantArrayBuilder::new(1);
    builder.append_json(json_text.as_str()).unwrap();

    let struct_array: StructArray = builder.build().into();
    ScalarValue::Struct(Arc::new(struct_array))
}
277+
/// Test helper: builds a variant array with one row per entry of `json_rows`
/// and returns it as an `ArrayRef` backed by a `StructArray`.
///
/// # Panics
///
/// Panics if the builder rejects any of the serialized JSON rows.
#[cfg(test)]
pub fn variant_array_from_json_rows(json_rows: &[serde_json::Value]) -> ArrayRef {
    let mut builder = VariantArrayBuilder::new(json_rows.len());

    json_rows.iter().for_each(|row| {
        let row_text = row.to_string();
        builder.append_json(row_text.as_str()).unwrap();
    });

    let struct_array: StructArray = builder.build().into();
    Arc::new(struct_array)
}
287+
/// Test helper: returns the two argument fields used by `variant_get`-style
/// tests: a nullable, empty-struct `"input"` field tagged with the
/// `VariantType` extension, followed by a nullable Utf8 `"path"` field.
#[cfg(test)]
pub fn standard_variant_get_arg_fields() -> Vec<Arc<Field>> {
    let input_field = Field::new("input", DataType::Struct(Fields::empty()), true)
        .with_extension_type(VariantType);
    let path_field = Field::new("path", DataType::Utf8, true);

    vec![Arc::new(input_field), Arc::new(path_field)]
}
298+
/// Test helper: assembles `ScalarFunctionArgs` for a two-argument
/// `(variant, path)` call.
///
/// The return field is a nullable field named `"result"` of type
/// `return_data_type`; `number_rows` and `config_options` take their
/// `Default` values.
#[cfg(test)]
pub fn build_variant_get_args(
    variant_input: ColumnarValue,
    path: ColumnarValue,
    return_data_type: DataType,
    arg_fields: Vec<Arc<Field>>,
) -> ScalarFunctionArgs {
    let return_field = Arc::new(Field::new("result", return_data_type, true));
    let call_args = vec![variant_input, path];

    ScalarFunctionArgs {
        args: call_args,
        return_field,
        arg_fields,
        number_rows: Default::default(),
        config_options: Default::default(),
    }
}
314+
142315#[ cfg( test) ]
143316#[ allow( unused) ]
144317pub fn build_variant_array_from_json_array ( jsons : & [ Option < serde_json:: Value > ] ) -> VariantArray {
0 commit comments