1
1
use std:: sync:: Arc ;
2
2
3
3
use parking_lot:: Mutex ;
4
- use polars_core:: prelude:: { InitHashMaps , PlHashMap , PlIndexMap } ;
4
+ use polars_core:: prelude:: { InitHashMaps , PlIndexMap } ;
5
5
use polars_core:: schema:: Schema ;
6
6
use polars_error:: { PolarsResult , polars_err} ;
7
7
use polars_expr:: state:: ExecutionState ;
8
8
use polars_mem_engine:: create_physical_plan;
9
9
use polars_plan:: plans:: expr_ir:: { ExprIR , OutputName } ;
10
- use polars_plan:: plans:: { AExpr , ArenaExprIter , DataFrameUdf , IR , IRAggExpr } ;
10
+ use polars_plan:: plans:: { AExpr , DataFrameUdf , IR , IRAggExpr , NaiveExprMerger } ;
11
11
use polars_plan:: prelude:: GroupbyOptions ;
12
12
use polars_utils:: arena:: { Arena , Node } ;
13
- use polars_utils:: itertools:: Itertools ;
14
13
use polars_utils:: pl_str:: PlSmallStr ;
15
14
use polars_utils:: unique_column_name;
16
15
use recursive:: recursive;
17
16
use slotmap:: SlotMap ;
18
17
19
- use super :: lower_expr:: lower_exprs;
20
18
use super :: { ExprCache , PhysNode , PhysNodeKey , PhysNodeKind , PhysStream } ;
21
19
use crate :: physical_plan:: lower_expr:: {
22
- build_select_stream, compute_output_schema, is_fake_elementwise_function, is_input_independent,
20
+ build_select_stream, compute_output_schema, is_elementwise_rec_cached,
21
+ is_fake_elementwise_function, is_input_independent,
23
22
} ;
24
23
use crate :: physical_plan:: lower_ir:: build_slice_stream;
25
24
use crate :: utils:: late_materialized_df:: LateMaterializedDataFrame ;
@@ -77,36 +76,34 @@ fn build_group_by_fallback(
77
76
#[ recursive]
78
77
fn try_lower_elementwise_scalar_agg_expr (
79
78
expr : Node ,
80
- inside_agg : bool ,
81
79
outer_name : Option < PlSmallStr > ,
80
+ expr_merger : & NaiveExprMerger ,
81
+ expr_cache : & mut ExprCache ,
82
82
expr_arena : & mut Arena < AExpr > ,
83
83
agg_exprs : & mut Vec < ExprIR > ,
84
- trans_input_cols : & PlHashMap < PlSmallStr , Node > ,
84
+ uniq_input_exprs : & mut PlIndexMap < u32 , PlSmallStr > ,
85
85
) -> Option < Node > {
86
86
// Helper macro to simplify recursive calls.
87
87
macro_rules! lower_rec {
88
- ( $input: expr, $inside_agg : expr ) => {
88
+ ( $input: expr) => {
89
89
try_lower_elementwise_scalar_agg_expr(
90
90
$input,
91
- $inside_agg,
92
91
None ,
92
+ expr_merger,
93
+ expr_cache,
93
94
expr_arena,
94
95
agg_exprs,
95
- trans_input_cols ,
96
+ uniq_input_exprs ,
96
97
)
97
98
} ;
98
99
}
99
100
100
101
match expr_arena. get ( expr) {
101
102
AExpr :: Alias ( ..) => unreachable ! ( "alias found in physical plan" ) ,
102
103
103
- AExpr :: Column ( c) => {
104
- if inside_agg {
105
- Some ( trans_input_cols[ c] )
106
- } else {
107
- // Implicit implode not yet supported.
108
- None
109
- }
104
+ AExpr :: Column ( _) => {
105
+ // Implicit implode not yet supported.
106
+ None
110
107
} ,
111
108
112
109
AExpr :: Literal ( lit) => {
@@ -131,8 +128,8 @@ fn try_lower_elementwise_scalar_agg_expr(
131
128
132
129
AExpr :: BinaryExpr { left, op, right } => {
133
130
let ( left, op, right) = ( * left, * op, * right) ;
134
- let left = lower_rec ! ( left, inside_agg ) ?;
135
- let right = lower_rec ! ( right, inside_agg ) ?;
131
+ let left = lower_rec ! ( left) ?;
132
+ let right = lower_rec ! ( right) ?;
136
133
Some ( expr_arena. add ( AExpr :: BinaryExpr { left, op, right } ) )
137
134
} ,
138
135
@@ -142,9 +139,9 @@ fn try_lower_elementwise_scalar_agg_expr(
142
139
falsy,
143
140
} => {
144
141
let ( predicate, truthy, falsy) = ( * predicate, * truthy, * falsy) ;
145
- let predicate = lower_rec ! ( predicate, inside_agg ) ?;
146
- let truthy = lower_rec ! ( truthy, inside_agg ) ?;
147
- let falsy = lower_rec ! ( falsy, inside_agg ) ?;
142
+ let predicate = lower_rec ! ( predicate) ?;
143
+ let truthy = lower_rec ! ( truthy) ?;
144
+ let falsy = lower_rec ! ( falsy) ?;
148
145
Some ( expr_arena. add ( AExpr :: Ternary {
149
146
predicate,
150
147
truthy,
@@ -162,7 +159,7 @@ fn try_lower_elementwise_scalar_agg_expr(
162
159
. into_iter ( )
163
160
. map ( |i| {
164
161
// The function may be sensitive to names (e.g. pl.struct), so we restore them.
165
- let new_node = lower_rec ! ( i. node( ) , inside_agg ) ?;
162
+ let new_node = lower_rec ! ( i. node( ) ) ?;
166
163
Some ( ExprIR :: new (
167
164
new_node,
168
165
OutputName :: Alias ( i. output_name ( ) . clone ( ) ) ,
@@ -188,7 +185,7 @@ fn try_lower_elementwise_scalar_agg_expr(
188
185
options,
189
186
} => {
190
187
let ( expr, dtype, options) = ( * expr, dtype. clone ( ) , * options) ;
191
- let expr = lower_rec ! ( expr, inside_agg ) ?;
188
+ let expr = lower_rec ! ( expr) ?;
192
189
Some ( expr_arena. add ( AExpr :: Cast {
193
190
expr,
194
191
dtype,
@@ -197,10 +194,6 @@ fn try_lower_elementwise_scalar_agg_expr(
197
194
} ,
198
195
199
196
AExpr :: Agg ( agg) => {
200
- // Nested aggregates not supported.
201
- if inside_agg {
202
- return None ;
203
- }
204
197
match agg {
205
198
IRAggExpr :: Min { input, .. }
206
199
| IRAggExpr :: Max { input, .. }
@@ -211,15 +204,27 @@ fn try_lower_elementwise_scalar_agg_expr(
211
204
| IRAggExpr :: Var ( input, ..)
212
205
| IRAggExpr :: Std ( input, ..)
213
206
| IRAggExpr :: Count ( input, ..) => {
214
- let orig_agg = agg. clone ( ) ;
215
- // Lower and replace input.
216
- let trans_input = lower_rec ! ( * input, true ) ?;
217
- let mut trans_agg = orig_agg;
218
- trans_agg. set_input ( trans_input) ;
207
+ if is_input_independent ( * input, expr_arena, expr_cache) {
208
+ // TODO: we could simply return expr here, but we first need an is_scalar function, because if
209
+ // it is not a scalar we need to return expr.implode().
210
+ return None ;
211
+ }
212
+
213
+ if !is_elementwise_rec_cached ( * input, expr_arena, expr_cache) {
214
+ return None ;
215
+ }
216
+
217
+ let mut trans_agg = agg. clone ( ) ;
218
+ let input_id = expr_merger. get_uniq_id ( * input) . unwrap ( ) ;
219
+ let input_col = uniq_input_exprs
220
+ . entry ( input_id)
221
+ . or_insert_with ( unique_column_name)
222
+ . clone ( ) ;
223
+ let input_col_node = expr_arena. add ( AExpr :: Column ( input_col. clone ( ) ) ) ;
224
+ trans_agg. set_input ( input_col_node) ;
219
225
let trans_agg_node = expr_arena. add ( AExpr :: Agg ( trans_agg) ) ;
220
226
221
227
// Add to aggregation expressions and replace with a reference to its output.
222
-
223
228
let agg_expr = if let Some ( name) = outer_name {
224
229
ExprIR :: new ( trans_agg_node, OutputName :: Alias ( name) )
225
230
} else {
@@ -284,67 +289,67 @@ fn try_build_streaming_group_by(
284
289
return None ;
285
290
}
286
291
287
- // We must lower the keys together with the input to the aggregations.
288
- let mut input_columns = PlIndexMap :: new ( ) ;
289
- for agg in aggs {
290
- for ( node, expr) in ( & * expr_arena) . iter ( agg. node ( ) ) {
291
- if let AExpr :: Column ( c) = expr {
292
- input_columns. insert ( c. clone ( ) , node) ;
293
- }
294
- }
292
+ // Fill all expressions into the merger, letting us extract common subexpressions later.
293
+ let mut expr_merger = NaiveExprMerger :: default ( ) ;
294
+ for key in keys {
295
+ expr_merger. add_expr ( key. node ( ) , expr_arena) ;
295
296
}
296
-
297
- let mut pre_lower_exprs = keys. to_vec ( ) ;
298
- for ( col, node) in input_columns. iter ( ) {
299
- pre_lower_exprs. push ( ExprIR :: new ( * node, OutputName :: ColumnLhs ( col. clone ( ) ) ) ) ;
297
+ for agg in aggs {
298
+ expr_merger. add_expr ( agg. node ( ) , expr_arena) ;
300
299
}
301
- let Ok ( ( trans_input, trans_exprs) ) =
302
- lower_exprs ( input, & pre_lower_exprs, expr_arena, phys_sm, expr_cache)
303
- else {
304
- return None ;
305
- } ;
306
- let trans_keys = trans_exprs[ ..keys. len ( ) ] . to_vec ( ) ;
307
- let trans_input_cols: PlHashMap < _ , _ > = trans_exprs[ keys. len ( ) ..]
308
- . iter ( )
309
- . zip ( input_columns. into_keys ( ) )
310
- . map ( |( expr, col) | ( col, expr. node ( ) ) )
311
- . collect ( ) ;
312
300
313
- // We must now lower each (presumed) scalar aggregate expression while
314
- // substituting the translated input columns and extracting the aggregate
315
- // expressions.
301
+ // Extract aggregates, input expressions for those aggregates and replace
302
+ // with agg node output columns.
303
+ let mut uniq_input_exprs = PlIndexMap :: new ( ) ;
316
304
let mut trans_agg_exprs = Vec :: new ( ) ;
317
- let mut trans_output_exprs = keys
318
- . iter ( )
319
- . map ( |key| {
320
- let key_node = expr_arena. add ( AExpr :: Column ( key. output_name ( ) . clone ( ) ) ) ;
321
- ExprIR :: from_node ( key_node, expr_arena)
322
- } )
323
- . collect_vec ( ) ;
305
+ let mut trans_keys = Vec :: new ( ) ;
306
+ let mut trans_output_exprs = Vec :: new ( ) ;
307
+ for key in keys {
308
+ let key_id = expr_merger. get_uniq_id ( key. node ( ) ) . unwrap ( ) ;
309
+ let uniq_col = uniq_input_exprs
310
+ . entry ( key_id)
311
+ . or_insert_with ( unique_column_name)
312
+ . clone ( ) ;
313
+ let trans_key_node = expr_arena. add ( AExpr :: Column ( uniq_col) ) ;
314
+ trans_keys. push ( ExprIR :: from_node ( trans_key_node, expr_arena) ) ;
315
+ let output_name = OutputName :: Alias ( key. output_name ( ) . clone ( ) ) ;
316
+ trans_output_exprs. push ( ExprIR :: new ( trans_key_node, output_name) ) ;
317
+ }
324
318
for agg in aggs {
325
319
let trans_node = try_lower_elementwise_scalar_agg_expr (
326
320
agg. node ( ) ,
327
- false ,
328
321
Some ( agg. output_name ( ) . clone ( ) ) ,
322
+ & expr_merger,
323
+ expr_cache,
329
324
expr_arena,
330
325
& mut trans_agg_exprs,
331
- & trans_input_cols ,
326
+ & mut uniq_input_exprs ,
332
327
) ?;
333
328
let output_name = OutputName :: Alias ( agg. output_name ( ) . clone ( ) ) ;
334
329
trans_output_exprs. push ( ExprIR :: new ( trans_node, output_name) ) ;
335
330
}
336
331
337
- let input_schema = & phys_sm[ trans_input. node ] . output_schema ;
332
+ // We must lower the keys together with the input to the aggregations.
333
+ let mut input_exprs = Vec :: new ( ) ;
334
+ for ( uniq_id, name) in uniq_input_exprs. iter ( ) {
335
+ let node = expr_merger. get_node ( * uniq_id) . unwrap ( ) ;
336
+ input_exprs. push ( ExprIR :: new ( node, OutputName :: Alias ( name. clone ( ) ) ) ) ;
337
+ }
338
+
339
+ let pre_select =
340
+ build_select_stream ( input, & input_exprs, expr_arena, phys_sm, expr_cache) . ok ( ) ?;
341
+
342
+ let input_schema = & phys_sm[ pre_select. node ] . output_schema ;
338
343
let group_by_output_schema = compute_output_schema (
339
344
input_schema,
340
- & [ trans_keys. clone ( ) , trans_agg_exprs. clone ( ) ] . concat ( ) ,
345
+ & [ trans_keys. as_slice ( ) , trans_agg_exprs. as_slice ( ) ] . concat ( ) ,
341
346
expr_arena,
342
347
)
343
348
. unwrap ( ) ;
344
349
let agg_node = phys_sm. insert ( PhysNode :: new (
345
350
group_by_output_schema,
346
351
PhysNodeKind :: GroupBy {
347
- input : trans_input ,
352
+ input : pre_select ,
348
353
key : trans_keys,
349
354
aggs : trans_agg_exprs,
350
355
} ,
0 commit comments