Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6a4b489

Browse files
authored Apr 9, 2025
perf: Add CSE to streaming groupby (#22196)
1 parent 3084f74 commit 6a4b489

File tree

4 files changed

+143
-72
lines changed

4 files changed

+143
-72
lines changed
 

Diff for: ‎crates/polars-plan/src/plans/optimizer/cse/cse_expr.rs

+63
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,69 @@ impl<V> IdentifierMap<V> {
151151
}
152152
}
153153

154+
/// Assigns the same unique ID to expressions that are identical.
155+
///
156+
/// Performs no analysis of whether substituting one such expression for
/// another is a legal rewrite.
157+
#[derive(Default)]
158+
pub struct NaiveExprMerger {
159+
/// Unique ID recorded for each node that has gone through `post_visit`.
node_to_uniq_id: PlHashMap<Node, u32>,
160+
/// Indexed by unique ID; holds the node that was being visited when that
/// ID was allocated.
uniq_id_to_node: Vec<Node>,
161+
/// Maps an expression identifier to the unique ID allocated for it.
identifier_to_uniq_id: IdentifierMap<u32>,
162+
/// Scratch stack used during a visit: `pre_visit` pushes one `None` entry
/// per node and `post_visit` pops entries off again.
arg_stack: Vec<Option<Identifier>>,
163+
}
164+
165+
impl NaiveExprMerger {
166+
/// Visits `node` with this merger acting as the visitor, so that unique
/// IDs are recorded for the visited expression nodes.
///
/// # Panics
///
/// Panics if the visit returns an error (the result is unwrapped).
pub fn add_expr(&mut self, node: Node, arena: &Arena<AExpr>) {
167+
let node = AexprNode::new(node);
168+
node.visit(self, arena).unwrap();
169+
}
170+
171+
/// Returns the unique ID recorded for `node`, or `None` if no ID has been
/// recorded for it.
pub fn get_uniq_id(&self, node: Node) -> Option<u32> {
172+
self.node_to_uniq_id.get(&node).copied()
173+
}
174+
175+
/// Returns the node stored for `uniq_id`, or `None` if `uniq_id` is not an
/// index into the list of allocated IDs.
pub fn get_node(&self, uniq_id: u32) -> Option<Node> {
176+
self.uniq_id_to_node.get(uniq_id as usize).copied()
177+
}
178+
}
179+
180+
impl Visitor for NaiveExprMerger {
181+
type Node = AexprNode;
182+
type Arena = Arena<AExpr>;
183+
184+
/// Pushes a `None` entry onto `arg_stack` for the node about to be
/// visited and lets the traversal continue.
fn pre_visit(
185+
&mut self,
186+
_node: &Self::Node,
187+
_arena: &Self::Arena,
188+
) -> PolarsResult<VisitRecursion> {
189+
self.arg_stack.push(None);
190+
Ok(VisitRecursion::Continue)
191+
}
192+
193+
/// Builds the identifier for `node`, resolves it to a unique ID and
/// records that ID for `node`.
fn post_visit(
194+
&mut self,
195+
node: &Self::Node,
196+
arena: &Self::Arena,
197+
) -> PolarsResult<VisitRecursion> {
198+
let mut identifier = Identifier::new();
199+
// Fold every `Some` entry on top of the stack into the identifier. The
// loop ends at the first `None` entry (which is popped as well) or when
// the stack is empty.
// NOTE(review): nothing in this impl pushes a `Some(..)` entry onto
// `arg_stack`, so as written this loop appears to pop only the `None`
// from `pre_visit` and never call `combine` - confirm whether the
// finished identifier was meant to be pushed for the parent node.
while let Some(Some(arg)) = self.arg_stack.pop() {
200+
identifier.combine(&arg);
201+
}
202+
identifier = identifier.add_ae_node(node, arena);
203+
// Resolve the identifier to its unique ID. The closure yields the next
// free ID (the current length of `uniq_id_to_node`) and stores this
// node under it; presumably `entry` only runs it when the identifier is
// not in the map yet - TODO confirm against `IdentifierMap::entry`.
let uniq_id = *self.identifier_to_uniq_id.entry(
204+
identifier,
205+
|| {
206+
// NOTE(review): `as u32` would silently truncate past `u32::MAX` entries.
let uniq_id = self.uniq_id_to_node.len() as u32;
207+
self.uniq_id_to_node.push(node.node());
208+
uniq_id
209+
},
210+
arena,
211+
);
212+
self.node_to_uniq_id.insert(node.node(), uniq_id);
213+
Ok(VisitRecursion::Continue)
214+
}
215+
}
216+
154217
/// Identifier maps to Expr Node and count.
155218
type SubExprCount = IdentifierMap<(Node, u32)>;
156219
/// (post_visit_idx, identifier);

Diff for: ‎crates/polars-plan/src/plans/optimizer/cse/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod cse_expr;
22
mod cse_lp;
33

44
pub(super) use cse_expr::CommonSubExprOptimizer;
5+
pub use cse_expr::NaiveExprMerger;
56
pub(super) use cse_lp::{elim_cmn_subplans, prune_unused_caches};
67

78
use super::*;

Diff for: ‎crates/polars-plan/src/plans/optimizer/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ mod slice_pushdown_lp;
2626
mod stack_opt;
2727

2828
use collapse_and_project::SimpleProjectionAndCollapse;
29+
#[cfg(feature = "cse")]
30+
pub use cse::NaiveExprMerger;
2931
use delay_rechunk::DelayRechunk;
3032
use polars_core::config::verbose;
3133
use polars_io::predicates::PhysicalIoExpr;

Diff for: ‎crates/polars-stream/src/physical_plan/lower_group_by.rs

+77-72
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
use std::sync::Arc;
22

33
use parking_lot::Mutex;
4-
use polars_core::prelude::{InitHashMaps, PlHashMap, PlIndexMap};
4+
use polars_core::prelude::{InitHashMaps, PlIndexMap};
55
use polars_core::schema::Schema;
66
use polars_error::{PolarsResult, polars_err};
77
use polars_expr::state::ExecutionState;
88
use polars_mem_engine::create_physical_plan;
99
use polars_plan::plans::expr_ir::{ExprIR, OutputName};
10-
use polars_plan::plans::{AExpr, ArenaExprIter, DataFrameUdf, IR, IRAggExpr};
10+
use polars_plan::plans::{AExpr, DataFrameUdf, IR, IRAggExpr, NaiveExprMerger};
1111
use polars_plan::prelude::GroupbyOptions;
1212
use polars_utils::arena::{Arena, Node};
13-
use polars_utils::itertools::Itertools;
1413
use polars_utils::pl_str::PlSmallStr;
1514
use polars_utils::unique_column_name;
1615
use recursive::recursive;
1716
use slotmap::SlotMap;
1817

19-
use super::lower_expr::lower_exprs;
2018
use super::{ExprCache, PhysNode, PhysNodeKey, PhysNodeKind, PhysStream};
2119
use crate::physical_plan::lower_expr::{
22-
build_select_stream, compute_output_schema, is_fake_elementwise_function, is_input_independent,
20+
build_select_stream, compute_output_schema, is_elementwise_rec_cached,
21+
is_fake_elementwise_function, is_input_independent,
2322
};
2423
use crate::physical_plan::lower_ir::build_slice_stream;
2524
use crate::utils::late_materialized_df::LateMaterializedDataFrame;
@@ -77,36 +76,34 @@ fn build_group_by_fallback(
7776
#[recursive]
7877
fn try_lower_elementwise_scalar_agg_expr(
7978
expr: Node,
80-
inside_agg: bool,
8179
outer_name: Option<PlSmallStr>,
80+
expr_merger: &NaiveExprMerger,
81+
expr_cache: &mut ExprCache,
8282
expr_arena: &mut Arena<AExpr>,
8383
agg_exprs: &mut Vec<ExprIR>,
84-
trans_input_cols: &PlHashMap<PlSmallStr, Node>,
84+
uniq_input_exprs: &mut PlIndexMap<u32, PlSmallStr>,
8585
) -> Option<Node> {
8686
// Helper macro to simplify recursive calls.
8787
macro_rules! lower_rec {
88-
($input:expr, $inside_agg:expr) => {
88+
($input:expr) => {
8989
try_lower_elementwise_scalar_agg_expr(
9090
$input,
91-
$inside_agg,
9291
None,
92+
expr_merger,
93+
expr_cache,
9394
expr_arena,
9495
agg_exprs,
95-
trans_input_cols,
96+
uniq_input_exprs,
9697
)
9798
};
9899
}
99100

100101
match expr_arena.get(expr) {
101102
AExpr::Alias(..) => unreachable!("alias found in physical plan"),
102103

103-
AExpr::Column(c) => {
104-
if inside_agg {
105-
Some(trans_input_cols[c])
106-
} else {
107-
// Implicit implode not yet supported.
108-
None
109-
}
104+
AExpr::Column(_) => {
105+
// Implicit implode not yet supported.
106+
None
110107
},
111108

112109
AExpr::Literal(lit) => {
@@ -131,8 +128,8 @@ fn try_lower_elementwise_scalar_agg_expr(
131128

132129
AExpr::BinaryExpr { left, op, right } => {
133130
let (left, op, right) = (*left, *op, *right);
134-
let left = lower_rec!(left, inside_agg)?;
135-
let right = lower_rec!(right, inside_agg)?;
131+
let left = lower_rec!(left)?;
132+
let right = lower_rec!(right)?;
136133
Some(expr_arena.add(AExpr::BinaryExpr { left, op, right }))
137134
},
138135

@@ -142,9 +139,9 @@ fn try_lower_elementwise_scalar_agg_expr(
142139
falsy,
143140
} => {
144141
let (predicate, truthy, falsy) = (*predicate, *truthy, *falsy);
145-
let predicate = lower_rec!(predicate, inside_agg)?;
146-
let truthy = lower_rec!(truthy, inside_agg)?;
147-
let falsy = lower_rec!(falsy, inside_agg)?;
142+
let predicate = lower_rec!(predicate)?;
143+
let truthy = lower_rec!(truthy)?;
144+
let falsy = lower_rec!(falsy)?;
148145
Some(expr_arena.add(AExpr::Ternary {
149146
predicate,
150147
truthy,
@@ -162,7 +159,7 @@ fn try_lower_elementwise_scalar_agg_expr(
162159
.into_iter()
163160
.map(|i| {
164161
// The function may be sensitive to names (e.g. pl.struct), so we restore them.
165-
let new_node = lower_rec!(i.node(), inside_agg)?;
162+
let new_node = lower_rec!(i.node())?;
166163
Some(ExprIR::new(
167164
new_node,
168165
OutputName::Alias(i.output_name().clone()),
@@ -188,7 +185,7 @@ fn try_lower_elementwise_scalar_agg_expr(
188185
options,
189186
} => {
190187
let (expr, dtype, options) = (*expr, dtype.clone(), *options);
191-
let expr = lower_rec!(expr, inside_agg)?;
188+
let expr = lower_rec!(expr)?;
192189
Some(expr_arena.add(AExpr::Cast {
193190
expr,
194191
dtype,
@@ -197,10 +194,6 @@ fn try_lower_elementwise_scalar_agg_expr(
197194
},
198195

199196
AExpr::Agg(agg) => {
200-
// Nested aggregates not supported.
201-
if inside_agg {
202-
return None;
203-
}
204197
match agg {
205198
IRAggExpr::Min { input, .. }
206199
| IRAggExpr::Max { input, .. }
@@ -211,15 +204,27 @@ fn try_lower_elementwise_scalar_agg_expr(
211204
| IRAggExpr::Var(input, ..)
212205
| IRAggExpr::Std(input, ..)
213206
| IRAggExpr::Count(input, ..) => {
214-
let orig_agg = agg.clone();
215-
// Lower and replace input.
216-
let trans_input = lower_rec!(*input, true)?;
217-
let mut trans_agg = orig_agg;
218-
trans_agg.set_input(trans_input);
207+
if is_input_independent(*input, expr_arena, expr_cache) {
208+
// TODO: we could simply return expr here, but we first need an is_scalar function, because if
209+
// it is not a scalar we need to return expr.implode().
210+
return None;
211+
}
212+
213+
if !is_elementwise_rec_cached(*input, expr_arena, expr_cache) {
214+
return None;
215+
}
216+
217+
let mut trans_agg = agg.clone();
218+
let input_id = expr_merger.get_uniq_id(*input).unwrap();
219+
let input_col = uniq_input_exprs
220+
.entry(input_id)
221+
.or_insert_with(unique_column_name)
222+
.clone();
223+
let input_col_node = expr_arena.add(AExpr::Column(input_col.clone()));
224+
trans_agg.set_input(input_col_node);
219225
let trans_agg_node = expr_arena.add(AExpr::Agg(trans_agg));
220226

221227
// Add to aggregation expressions and replace with a reference to its output.
222-
223228
let agg_expr = if let Some(name) = outer_name {
224229
ExprIR::new(trans_agg_node, OutputName::Alias(name))
225230
} else {
@@ -284,67 +289,67 @@ fn try_build_streaming_group_by(
284289
return None;
285290
}
286291

287-
// We must lower the keys together with the input to the aggregations.
288-
let mut input_columns = PlIndexMap::new();
289-
for agg in aggs {
290-
for (node, expr) in (&*expr_arena).iter(agg.node()) {
291-
if let AExpr::Column(c) = expr {
292-
input_columns.insert(c.clone(), node);
293-
}
294-
}
292+
// Fill all expressions into the merger, letting us extract common subexpressions later.
293+
let mut expr_merger = NaiveExprMerger::default();
294+
for key in keys {
295+
expr_merger.add_expr(key.node(), expr_arena);
295296
}
296-
297-
let mut pre_lower_exprs = keys.to_vec();
298-
for (col, node) in input_columns.iter() {
299-
pre_lower_exprs.push(ExprIR::new(*node, OutputName::ColumnLhs(col.clone())));
297+
for agg in aggs {
298+
expr_merger.add_expr(agg.node(), expr_arena);
300299
}
301-
let Ok((trans_input, trans_exprs)) =
302-
lower_exprs(input, &pre_lower_exprs, expr_arena, phys_sm, expr_cache)
303-
else {
304-
return None;
305-
};
306-
let trans_keys = trans_exprs[..keys.len()].to_vec();
307-
let trans_input_cols: PlHashMap<_, _> = trans_exprs[keys.len()..]
308-
.iter()
309-
.zip(input_columns.into_keys())
310-
.map(|(expr, col)| (col, expr.node()))
311-
.collect();
312300

313-
// We must now lower each (presumed) scalar aggregate expression while
314-
// substituting the translated input columns and extracting the aggregate
315-
// expressions.
301+
// Extract aggregates, input expressions for those aggregates and replace
302+
// with agg node output columns.
303+
let mut uniq_input_exprs = PlIndexMap::new();
316304
let mut trans_agg_exprs = Vec::new();
317-
let mut trans_output_exprs = keys
318-
.iter()
319-
.map(|key| {
320-
let key_node = expr_arena.add(AExpr::Column(key.output_name().clone()));
321-
ExprIR::from_node(key_node, expr_arena)
322-
})
323-
.collect_vec();
305+
let mut trans_keys = Vec::new();
306+
let mut trans_output_exprs = Vec::new();
307+
for key in keys {
308+
let key_id = expr_merger.get_uniq_id(key.node()).unwrap();
309+
let uniq_col = uniq_input_exprs
310+
.entry(key_id)
311+
.or_insert_with(unique_column_name)
312+
.clone();
313+
let trans_key_node = expr_arena.add(AExpr::Column(uniq_col));
314+
trans_keys.push(ExprIR::from_node(trans_key_node, expr_arena));
315+
let output_name = OutputName::Alias(key.output_name().clone());
316+
trans_output_exprs.push(ExprIR::new(trans_key_node, output_name));
317+
}
324318
for agg in aggs {
325319
let trans_node = try_lower_elementwise_scalar_agg_expr(
326320
agg.node(),
327-
false,
328321
Some(agg.output_name().clone()),
322+
&expr_merger,
323+
expr_cache,
329324
expr_arena,
330325
&mut trans_agg_exprs,
331-
&trans_input_cols,
326+
&mut uniq_input_exprs,
332327
)?;
333328
let output_name = OutputName::Alias(agg.output_name().clone());
334329
trans_output_exprs.push(ExprIR::new(trans_node, output_name));
335330
}
336331

337-
let input_schema = &phys_sm[trans_input.node].output_schema;
332+
// We must lower the keys together with the input to the aggregations.
333+
let mut input_exprs = Vec::new();
334+
for (uniq_id, name) in uniq_input_exprs.iter() {
335+
let node = expr_merger.get_node(*uniq_id).unwrap();
336+
input_exprs.push(ExprIR::new(node, OutputName::Alias(name.clone())));
337+
}
338+
339+
let pre_select =
340+
build_select_stream(input, &input_exprs, expr_arena, phys_sm, expr_cache).ok()?;
341+
342+
let input_schema = &phys_sm[pre_select.node].output_schema;
338343
let group_by_output_schema = compute_output_schema(
339344
input_schema,
340-
&[trans_keys.clone(), trans_agg_exprs.clone()].concat(),
345+
&[trans_keys.as_slice(), trans_agg_exprs.as_slice()].concat(),
341346
expr_arena,
342347
)
343348
.unwrap();
344349
let agg_node = phys_sm.insert(PhysNode::new(
345350
group_by_output_schema,
346351
PhysNodeKind::GroupBy {
347-
input: trans_input,
352+
input: pre_select,
348353
key: trans_keys,
349354
aggs: trans_agg_exprs,
350355
},

0 commit comments

Comments
 (0)
Please sign in to comment.