Skip to content

Commit 9099587

Browse files
committed
Pushdown some expressions to Dict layout reader
1 parent 583cbed commit 9099587

2 files changed

Lines changed: 145 additions & 7 deletions

File tree

vortex-array/src/scalar_fn/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,18 @@ mod sealed {
4848
/// This can be the **only** implementor for [`super::typed::DynScalarFn`].
4949
impl<V: ScalarFnVTable> Sealed for TypedScalarFnInstance<V> {}
5050
}
51+
52+
/*
53+
* A scalar function has a negative cost if applying it to an array and
54+
* canonicalizing is cheaper than canonicalizing an array and applying it.
55+
*
56+
* Example of negative cost expressions are byte_length() and get_item() since
57+
* they don't depend on input size.
58+
*
59+
* Example of non-negative cost expression is like()
60+
*/
61+
pub fn is_negative_cost(id: ScalarFnId) -> bool {
62+
id == Id::new_static("vortex.byte_length")
63+
|| id == Id::new_static("vortex.get_item")
64+
|| id == Id::new_static("vortex.literal")
65+
}

vortex-layout/src/layouts/dict/reader.rs

Lines changed: 130 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,15 @@ use vortex_array::arrays::SharedArray;
1919
use vortex_array::dtype::DType;
2020
use vortex_array::dtype::FieldMask;
2121
use vortex_array::expr::Expression;
22+
use vortex_array::expr::is_root;
23+
use vortex_array::expr::label_is_fallible;
24+
use vortex_array::expr::label_null_sensitive;
2225
use vortex_array::expr::root;
26+
use vortex_array::expr::traversal::NodeExt;
27+
use vortex_array::expr::traversal::Transformed;
28+
use vortex_array::expr::traversal::TraversalOrder;
2329
use vortex_array::optimizer::ArrayOptimizer;
30+
use vortex_array::scalar_fn::is_negative_cost;
2431
use vortex_error::VortexError;
2532
use vortex_error::VortexExpect;
2633
use vortex_error::VortexResult;
@@ -100,10 +107,7 @@ impl DictReader {
100107
)
101108
.vortex_expect("must construct dict values array evaluation")
102109
.map_err(Arc::new)
103-
.map(move |array| {
104-
let array = array?;
105-
Ok(SharedArray::new(array).into_array())
106-
})
110+
.map(move |array| Ok(SharedArray::new(array?).into_array()))
107111
.boxed()
108112
.shared()
109113
})
@@ -155,6 +159,49 @@ impl DictReader {
155159
}
156160
}
157161

162+
fn references_root(expr: &Expression) -> bool {
163+
is_root(expr) || expr.children().iter().any(references_root)
164+
}
165+
166+
/// Split expression into two parts:
167+
///
168+
/// left is the optional outer part that we want to apply to array after
169+
/// canonicalizing.
170+
/// right is the optional inner part that we want to apply to array before
171+
/// canonicalizing.
172+
///
173+
/// We want to push to array only if expression has a negative cost, is
174+
/// infallible and null-insensitive.
175+
fn split_expression_for_pushdown(expr: Expression) -> (Option<Expression>, Option<Expression>) {
176+
let labelled_expr = expr.clone();
177+
let fallible = label_is_fallible(&labelled_expr);
178+
let null_sensitive = label_null_sensitive(&labelled_expr);
179+
let mut inner: Option<Expression> = None;
180+
181+
let outer = expr
182+
.transform_down(|node| {
183+
if is_negative_cost(node.id())
184+
&& references_root(&node)
185+
&& !fallible.get(&node).copied().unwrap_or(true)
186+
&& !null_sensitive.get(&node).copied().unwrap_or(true)
187+
{
188+
inner = Some(node);
189+
Ok(Transformed {
190+
value: root(),
191+
changed: true,
192+
order: TraversalOrder::Skip,
193+
})
194+
} else {
195+
Ok(Transformed::no(node))
196+
}
197+
})
198+
.vortex_expect("infallible")
199+
.into_inner();
200+
201+
let outer = (!is_root(&outer)).then_some(outer);
202+
(outer, inner)
203+
}
204+
158205
impl LayoutReader for DictReader {
159206
fn name(&self) -> &Arc<str> {
160207
&self.name
@@ -229,13 +276,18 @@ impl LayoutReader for DictReader {
229276
mask: MaskFuture,
230277
) -> VortexResult<BoxFuture<'static, VortexResult<ArrayRef>>> {
231278
// TODO: fix up expr partitioning with fallible & null sensitive annotations
232-
let values_eval = self.values_array();
233279
let codes_eval = self
234280
.codes
235281
.projection_evaluation(row_range, &root(), mask)
236282
.map_err(|err| err.with_context("While evaluating projection on codes"))?;
237-
let expr = expr.clone();
238283

284+
let (expr_outer, expr_inner) = split_expression_for_pushdown(expr.clone());
285+
286+
let values_eval = if let Some(inner) = expr_inner {
287+
self.values_eval(inner)
288+
} else {
289+
self.values_array()
290+
};
239291
let all_values_referenced = self.layout.has_all_values_referenced();
240292
Ok(async move {
241293
let (values, codes) = try_join!(values_eval.map_err(VortexError::from), codes_eval)?;
@@ -252,7 +304,11 @@ impl LayoutReader for DictReader {
252304
.into_array()
253305
.optimize()?;
254306

255-
array.apply(&expr)
307+
if let Some(expr) = expr_outer {
308+
array.apply(&expr)
309+
} else {
310+
Ok(array)
311+
}
256312
}
257313
.boxed())
258314
}
@@ -281,11 +337,20 @@ mod tests {
281337
use vortex_array::dtype::FieldName;
282338
use vortex_array::dtype::FieldNames;
283339
use vortex_array::dtype::Nullability;
340+
use vortex_array::dtype::PType;
341+
use vortex_array::expr::Expression;
342+
use vortex_array::expr::byte_length;
343+
use vortex_array::expr::cast;
284344
use vortex_array::expr::eq;
285345
use vortex_array::expr::is_not_null;
346+
use vortex_array::expr::is_root;
347+
use vortex_array::expr::like;
286348
use vortex_array::expr::lit;
287349
use vortex_array::expr::pack;
288350
use vortex_array::expr::root;
351+
use vortex_array::expr::traversal::NodeExt;
352+
use vortex_array::expr::traversal::Transformed;
353+
use vortex_array::expr::traversal::TraversalOrder;
289354
use vortex_array::scalar_fn::session::ScalarFnSession;
290355
use vortex_array::session::ArraySession;
291356
use vortex_array::validity::Validity;
@@ -296,6 +361,7 @@ mod tests {
296361
use vortex_io::session::RuntimeSessionExt;
297362
use vortex_session::VortexSession;
298363

364+
use super::split_expression_for_pushdown;
299365
use crate::LayoutId;
300366
use crate::LayoutRef;
301367
use crate::LayoutStrategy;
@@ -542,4 +608,61 @@ mod tests {
542608
assert_arrays_eq!(actual_canonical, expected);
543609
})
544610
}
611+
612+
fn join_split_expr(initial: &Expression, outer: Option<Expression>, inner: Option<Expression>) {
613+
let outer_expr = outer.unwrap_or_else(root);
614+
let inner_expr = inner.unwrap_or_else(root);
615+
let expected = outer_expr
616+
.transform_down(|node| {
617+
if !is_root(&node) {
618+
return Ok(Transformed::no(node));
619+
}
620+
Ok(Transformed {
621+
value: inner_expr.clone(),
622+
changed: true,
623+
order: TraversalOrder::Skip,
624+
})
625+
})
626+
.vortex_expect("infallible");
627+
assert_eq!(&expected.into_inner(), initial);
628+
}
629+
630+
#[test]
631+
fn split_expr_cast_root() {
632+
let (outer, inner) = split_expression_for_pushdown(root());
633+
assert_eq!(outer, None);
634+
assert_eq!(inner, None); // Applying root to array is useless work
635+
}
636+
637+
#[test]
638+
fn split_expr_partial_pushdown() {
639+
let dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
640+
let expr = cast(byte_length(root()), dtype.clone());
641+
let (outer, inner) = split_expression_for_pushdown(expr.clone());
642+
// [0] = cast([1], dtype)
643+
// [1] = byte_length(root)
644+
assert_eq!(outer, Some(cast(root(), dtype)));
645+
assert_eq!(inner, Some(byte_length(root())));
646+
join_split_expr(&expr, outer, inner);
647+
}
648+
649+
#[test]
650+
fn split_expr_full_pushdown() {
651+
let expr = byte_length(root());
652+
let (outer, inner) = split_expression_for_pushdown(expr.clone());
653+
assert_eq!(outer, None);
654+
assert_eq!(inner, Some(byte_length(root())));
655+
join_split_expr(&expr, outer, inner);
656+
}
657+
658+
#[test]
659+
fn split_expr_no_pushdown() {
660+
// We can push down lit(), but it we replace
661+
// lit() with root(), the semantics change.
662+
let expr = like(root(), lit(1u64));
663+
let (outer, inner) = split_expression_for_pushdown(expr.clone());
664+
assert_eq!(outer, Some(expr.clone()));
665+
assert_eq!(inner, None);
666+
join_split_expr(&expr, outer, inner);
667+
}
545668
}

0 commit comments

Comments
 (0)