Skip to content

Commit 861e3e3

Browse files
authored
fix(rust, python): ensure that no-projection is seen as select all in… (#5356)
1 parent 0f207f0 commit 861e3e3

File tree

4 files changed

+100
-14
lines changed

4 files changed

+100
-14
lines changed

polars/polars-core/src/datatypes/aliases.rs

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,3 +69,15 @@ impl<K> InitHashMaps for PlIndexSet<K> {
6969
Self::with_capacity_and_hasher(capacity, Default::default())
7070
}
7171
}
72+
73+
impl<K, V> InitHashMaps for PlIndexMap<K, V> {
74+
type HashMap = Self;
75+
76+
fn new() -> Self::HashMap {
77+
Self::with_capacity_and_hasher(0, Default::default())
78+
}
79+
80+
fn with_capacity(capacity: usize) -> Self::HashMap {
81+
Self::with_capacity_and_hasher(capacity, Default::default())
82+
}
83+
}

polars/polars-lazy/polars-plan/src/logical_plan/optimizer/cache_states.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ fn get_upper_projections(
1515
use ALogicalPlan::*;
1616
// during projection pushdown all accumulated p
1717
match parent {
18-
Projection { expr, .. } => {
18+
Projection { expr, .. } | HStack { exprs: expr, .. } => {
1919
let mut out = Vec::with_capacity(expr.len());
2020
for node in expr {
2121
out.extend(aexpr_to_leaf_names_iter(*node, expr_arena));
@@ -92,12 +92,9 @@ pub(super) fn set_cache_states(
9292
if let Some(names) = get_upper_projections(parent_node, lp_arena, expr_arena) {
9393
entry.1.extend(names);
9494
}
95-
// if there is no projection above, it may be that the
96-
// cache is underneath another cache and projection pushdown never reached it.
97-
// other trails may take care of that cache
98-
// if there is no other cache above, then there was no projection and we must take
95+
// There was no projection and we must take
9996
// all columns
100-
else if previous_cache.is_none() {
97+
else {
10198
let schema = lp.schema(lp_arena);
10299
entry
103100
.1

polars/polars-lazy/polars-plan/src/logical_plan/optimizer/cse.rs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,14 @@ fn lp_node_equal(a: &ALogicalPlan, b: &ALogicalPlan, expr_arena: &Arena<AExpr>)
239239
}
240240

241241
/// Iterate from two leaf locations upwards and find the latest matching node.
242+
///
243+
/// Returns the matching nodes
242244
fn longest_subgraph(
243245
trail_a: &Trail,
244246
trail_b: &Trail,
245247
lp_arena: &Arena<ALogicalPlan>,
246248
expr_arena: &Arena<AExpr>,
247-
) -> Option<(usize, Node, Node)> {
249+
) -> Option<(Node, Node)> {
248250
if trail_a.is_empty() || trail_b.is_empty() {
249251
return None;
250252
}
@@ -276,7 +278,7 @@ fn longest_subgraph(
276278
}
277279
// previous node was equal
278280
if i > 0 {
279-
Some((i - 1, prev_node_a, prev_node_b))
281+
Some((prev_node_a, prev_node_b))
280282
} else {
281283
None
282284
}
@@ -317,14 +319,14 @@ pub(crate) fn elim_cmn_subplans(
317319
let mut changed = false;
318320

319321
let mut cache_mapping = BTreeMap::new();
322+
let mut cache_counts = PlHashMap::with_capacity(trail_ends.len());
320323

321-
// insert cache nodes
322324
for combination in trail_ends.iter() {
323325
// both are the same, but only point to a different location
324326
// in our arena so we hash one and store the hash for both locations
325327
// this will ensure all matches have the same hash.
326-
let node1 = combination.1 .0;
327-
let node2 = combination.2 .0;
328+
let node1 = combination.0 .0;
329+
let node2 = combination.1 .0;
328330

329331
let cache_id = match (cache_mapping.get(&node1), cache_mapping.get(&node2)) {
330332
(Some(h), _) => *h,
@@ -345,12 +347,31 @@ pub(crate) fn elim_cmn_subplans(
345347
cache_id
346348
}
347349
};
350+
*cache_counts.entry(cache_id).or_insert(0usize) += 1;
351+
}
352+
353+
// insert cache nodes
354+
for combination in trail_ends.iter() {
355+
// both are the same, but only point to a different location
356+
// in our arena so we hash one and store the hash for both locations
357+
// this will ensure all matches have the same hash.
358+
let node1 = combination.0 .0;
359+
let node2 = combination.1 .0;
360+
361+
let cache_id = match (cache_mapping.get(&node1), cache_mapping.get(&node2)) {
362+
(Some(h), _) => *h,
363+
(_, Some(h)) => *h,
364+
_ => {
365+
unreachable!()
366+
}
367+
};
368+
let cache_count = *cache_counts.get(&cache_id).unwrap();
348369

349370
// reassign old nodes to another location as we are going to replace
350371
// them with a cache node
351-
for inp_node in [combination.1, combination.2] {
372+
for inp_node in [combination.0, combination.1] {
352373
if let ALogicalPlan::Cache { count, .. } = lp_arena.get_mut(inp_node) {
353-
*count += 1;
374+
*count = cache_count;
354375
} else {
355376
let lp = lp_arena.get(inp_node).clone();
356377

@@ -360,7 +381,7 @@ pub(crate) fn elim_cmn_subplans(
360381
input: node,
361382
id: cache_id,
362383
// remove after one cache hit.
363-
count: 1,
384+
count: cache_count,
364385
};
365386
lp_arena.replace(inp_node, cache_lp.clone());
366387
};

polars/polars-lazy/src/tests/cse.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,59 @@ fn test_cse_joins_4954() -> PolarsResult<()> {
198198

199199
Ok(())
200200
}
201+
#[test]
202+
#[cfg(feature = "semi_anti_join")]
203+
fn test_cache_with_partial_projection() -> PolarsResult<()> {
204+
let lf1 = df![
205+
"id" => ["a"],
206+
"x" => [1],
207+
"freq" => [2]
208+
]?
209+
.lazy();
210+
211+
let lf2 = df![
212+
"id" => ["a"]
213+
]?
214+
.lazy();
215+
216+
let q = lf2
217+
.join(
218+
lf1.clone().select([col("id"), col("freq")]),
219+
[col("id")],
220+
[col("id")],
221+
JoinType::Semi,
222+
)
223+
.join(
224+
lf1.clone().filter(col("x").neq(lit(8))),
225+
[col("id")],
226+
[col("id")],
227+
JoinType::Semi,
228+
)
229+
.join(
230+
lf1.clone().filter(col("x").neq(lit(8))),
231+
[col("id")],
232+
[col("id")],
233+
JoinType::Semi,
234+
);
235+
236+
let q = q.with_common_subplan_elimination(true);
237+
238+
let (mut expr_arena, mut lp_arena) = get_arenas();
239+
let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();
240+
241+
// ensure we get two different caches
242+
// and ensure that every cache only has 1 hit.
243+
let cache_ids = (&lp_arena)
244+
.iter(lp)
245+
.flat_map(|(_, lp)| {
246+
use ALogicalPlan::*;
247+
match lp {
248+
Cache { id, .. } => Some(*id),
249+
_ => None,
250+
}
251+
})
252+
.collect::<BTreeSet<_>>();
253+
assert_eq!(cache_ids.len(), 2);
254+
255+
Ok(())
256+
}

0 commit comments

Comments
 (0)