@@ -19,8 +19,15 @@ use vortex_array::arrays::SharedArray;
1919use vortex_array:: dtype:: DType ;
2020use vortex_array:: dtype:: FieldMask ;
2121use vortex_array:: expr:: Expression ;
22+ use vortex_array:: expr:: is_root;
23+ use vortex_array:: expr:: label_is_fallible;
24+ use vortex_array:: expr:: label_null_sensitive;
2225use vortex_array:: expr:: root;
26+ use vortex_array:: expr:: traversal:: NodeExt ;
27+ use vortex_array:: expr:: traversal:: Transformed ;
28+ use vortex_array:: expr:: traversal:: TraversalOrder ;
2329use vortex_array:: optimizer:: ArrayOptimizer ;
30+ use vortex_array:: scalar_fn:: is_negative_cost;
2431use vortex_error:: VortexError ;
2532use vortex_error:: VortexExpect ;
2633use vortex_error:: VortexResult ;
@@ -100,10 +107,7 @@ impl DictReader {
100107 )
101108 . vortex_expect ( "must construct dict values array evaluation" )
102109 . map_err ( Arc :: new)
103- . map ( move |array| {
104- let array = array?;
105- Ok ( SharedArray :: new ( array) . into_array ( ) )
106- } )
110+ . map ( move |array| Ok ( SharedArray :: new ( array?) . into_array ( ) ) )
107111 . boxed ( )
108112 . shared ( )
109113 } )
@@ -155,6 +159,49 @@ impl DictReader {
155159 }
156160}
157161
162+ fn references_root ( expr : & Expression ) -> bool {
163+ is_root ( expr) || expr. children ( ) . iter ( ) . any ( references_root)
164+ }
165+
166+ /// Split expression into two parts:
167+ ///
168+ /// left is the optional outer part that we want to apply to array after
169+ /// canonicalizing.
170+ /// right is the optional inner part that we want to apply to array before
171+ /// canonicalizing.
172+ ///
173+ /// We want to push to array only if expression has a negative cost, is
174+ /// infallible and null-insensitive.
175+ fn split_expression_for_pushdown ( expr : Expression ) -> ( Option < Expression > , Option < Expression > ) {
176+ let labelled_expr = expr. clone ( ) ;
177+ let fallible = label_is_fallible ( & labelled_expr) ;
178+ let null_sensitive = label_null_sensitive ( & labelled_expr) ;
179+ let mut inner: Option < Expression > = None ;
180+
181+ let outer = expr
182+ . transform_down ( |node| {
183+ if is_negative_cost ( node. id ( ) )
184+ && references_root ( & node)
185+ && !fallible. get ( & node) . copied ( ) . unwrap_or ( true )
186+ && !null_sensitive. get ( & node) . copied ( ) . unwrap_or ( true )
187+ {
188+ inner = Some ( node) ;
189+ Ok ( Transformed {
190+ value : root ( ) ,
191+ changed : true ,
192+ order : TraversalOrder :: Skip ,
193+ } )
194+ } else {
195+ Ok ( Transformed :: no ( node) )
196+ }
197+ } )
198+ . vortex_expect ( "infallible" )
199+ . into_inner ( ) ;
200+
201+ let outer = ( !is_root ( & outer) ) . then_some ( outer) ;
202+ ( outer, inner)
203+ }
204+
158205impl LayoutReader for DictReader {
159206 fn name ( & self ) -> & Arc < str > {
160207 & self . name
@@ -229,13 +276,18 @@ impl LayoutReader for DictReader {
229276 mask : MaskFuture ,
230277 ) -> VortexResult < BoxFuture < ' static , VortexResult < ArrayRef > > > {
231278 // TODO: fix up expr partitioning with fallible & null sensitive annotations
232- let values_eval = self . values_array ( ) ;
233279 let codes_eval = self
234280 . codes
235281 . projection_evaluation ( row_range, & root ( ) , mask)
236282 . map_err ( |err| err. with_context ( "While evaluating projection on codes" ) ) ?;
237- let expr = expr. clone ( ) ;
238283
284+ let ( expr_outer, expr_inner) = split_expression_for_pushdown ( expr. clone ( ) ) ;
285+
286+ let values_eval = if let Some ( inner) = expr_inner {
287+ self . values_eval ( inner)
288+ } else {
289+ self . values_array ( )
290+ } ;
239291 let all_values_referenced = self . layout . has_all_values_referenced ( ) ;
240292 Ok ( async move {
241293 let ( values, codes) = try_join ! ( values_eval. map_err( VortexError :: from) , codes_eval) ?;
@@ -252,7 +304,11 @@ impl LayoutReader for DictReader {
252304 . into_array ( )
253305 . optimize ( ) ?;
254306
255- array. apply ( & expr)
307+ if let Some ( expr) = expr_outer {
308+ array. apply ( & expr)
309+ } else {
310+ Ok ( array)
311+ }
256312 }
257313 . boxed ( ) )
258314 }
@@ -281,11 +337,20 @@ mod tests {
281337 use vortex_array:: dtype:: FieldName ;
282338 use vortex_array:: dtype:: FieldNames ;
283339 use vortex_array:: dtype:: Nullability ;
340+ use vortex_array:: dtype:: PType ;
341+ use vortex_array:: expr:: Expression ;
342+ use vortex_array:: expr:: byte_length;
343+ use vortex_array:: expr:: cast;
284344 use vortex_array:: expr:: eq;
285345 use vortex_array:: expr:: is_not_null;
346+ use vortex_array:: expr:: is_root;
347+ use vortex_array:: expr:: like;
286348 use vortex_array:: expr:: lit;
287349 use vortex_array:: expr:: pack;
288350 use vortex_array:: expr:: root;
351+ use vortex_array:: expr:: traversal:: NodeExt ;
352+ use vortex_array:: expr:: traversal:: Transformed ;
353+ use vortex_array:: expr:: traversal:: TraversalOrder ;
289354 use vortex_array:: scalar_fn:: session:: ScalarFnSession ;
290355 use vortex_array:: session:: ArraySession ;
291356 use vortex_array:: validity:: Validity ;
@@ -296,6 +361,7 @@ mod tests {
296361 use vortex_io:: session:: RuntimeSessionExt ;
297362 use vortex_session:: VortexSession ;
298363
364+ use super :: split_expression_for_pushdown;
299365 use crate :: LayoutId ;
300366 use crate :: LayoutRef ;
301367 use crate :: LayoutStrategy ;
@@ -542,4 +608,61 @@ mod tests {
542608 assert_arrays_eq ! ( actual_canonical, expected) ;
543609 } )
544610 }
611+
612+ fn join_split_expr ( initial : & Expression , outer : Option < Expression > , inner : Option < Expression > ) {
613+ let outer_expr = outer. unwrap_or_else ( root) ;
614+ let inner_expr = inner. unwrap_or_else ( root) ;
615+ let expected = outer_expr
616+ . transform_down ( |node| {
617+ if !is_root ( & node) {
618+ return Ok ( Transformed :: no ( node) ) ;
619+ }
620+ Ok ( Transformed {
621+ value : inner_expr. clone ( ) ,
622+ changed : true ,
623+ order : TraversalOrder :: Skip ,
624+ } )
625+ } )
626+ . vortex_expect ( "infallible" ) ;
627+ assert_eq ! ( & expected. into_inner( ) , initial) ;
628+ }
629+
630+ #[ test]
631+ fn split_expr_cast_root ( ) {
632+ let ( outer, inner) = split_expression_for_pushdown ( root ( ) ) ;
633+ assert_eq ! ( outer, None ) ;
634+ assert_eq ! ( inner, None ) ; // Applying root to array is useless work
635+ }
636+
637+ #[ test]
638+ fn split_expr_partial_pushdown ( ) {
639+ let dtype = DType :: Primitive ( PType :: U64 , Nullability :: NonNullable ) ;
640+ let expr = cast ( byte_length ( root ( ) ) , dtype. clone ( ) ) ;
641+ let ( outer, inner) = split_expression_for_pushdown ( expr. clone ( ) ) ;
642+ // [0] = cast([1], dtype)
643+ // [1] = byte_length(root)
644+ assert_eq ! ( outer, Some ( cast( root( ) , dtype) ) ) ;
645+ assert_eq ! ( inner, Some ( byte_length( root( ) ) ) ) ;
646+ join_split_expr ( & expr, outer, inner) ;
647+ }
648+
649+ #[ test]
650+ fn split_expr_full_pushdown ( ) {
651+ let expr = byte_length ( root ( ) ) ;
652+ let ( outer, inner) = split_expression_for_pushdown ( expr. clone ( ) ) ;
653+ assert_eq ! ( outer, None ) ;
654+ assert_eq ! ( inner, Some ( byte_length( root( ) ) ) ) ;
655+ join_split_expr ( & expr, outer, inner) ;
656+ }
657+
658+ #[ test]
659+ fn split_expr_no_pushdown ( ) {
660+ // We can push down lit(), but it we replace
661+ // lit() with root(), the semantics change.
662+ let expr = like ( root ( ) , lit ( 1u64 ) ) ;
663+ let ( outer, inner) = split_expression_for_pushdown ( expr. clone ( ) ) ;
664+ assert_eq ! ( outer, Some ( expr. clone( ) ) ) ;
665+ assert_eq ! ( inner, None ) ;
666+ join_split_expr ( & expr, outer, inner) ;
667+ }
545668}
0 commit comments