Skip to content

Commit 3c497e3

Browse files
perf: Remove cast to boolean after comparison in optimizer (#21022)
1 parent c7888de commit 3c497e3

File tree

8 files changed

+144
-46
lines changed

8 files changed

+144
-46
lines changed

Diff for: crates/polars-plan/src/plans/conversion/dsl_to_ir.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ pub(super) fn run_conversion(
117117
) -> PolarsResult<Node> {
118118
let lp_node = ctxt.lp_arena.add(lp);
119119
ctxt.conversion_optimizer
120-
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node)
120+
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, lp_node)
121121
.map_err(|e| e.context(format!("'{name}' failed").into()))?;
122122

123123
Ok(lp_node)

Diff for: crates/polars-plan/src/plans/conversion/join.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ pub fn resolve_join(
129129
ctxt.conversion_optimizer
130130
.fill_scratch(&left_on, ctxt.expr_arena);
131131
ctxt.conversion_optimizer
132-
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_left)
132+
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_left)
133133
.map_err(|e| e.context("'join' failed".into()))?;
134134
ctxt.conversion_optimizer
135135
.fill_scratch(&right_on, ctxt.expr_arena);
136136
ctxt.conversion_optimizer
137-
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right)
137+
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, input_right)
138138
.map_err(|e| e.context("'join' failed".into()))?;
139139

140140
// Re-evaluate because of mutable borrows earlier.
@@ -447,11 +447,20 @@ fn resolve_join_where(
447447
for e in predicates {
448448
let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;
449449

450+
ctxt.conversion_optimizer
451+
.push_scratch(predicate.node(), ctxt.expr_arena);
452+
450453
let ir = IR::Filter {
451454
input: last_node,
452455
predicate,
453456
};
457+
454458
last_node = ctxt.lp_arena.add(ir);
455459
}
460+
461+
ctxt.conversion_optimizer
462+
.optimize_exprs(ctxt.expr_arena, ctxt.lp_arena, last_node)
463+
.map_err(|e| e.context("'join_where' failed".into()))?;
464+
456465
Ok((last_node, join_node))
457466
}

Diff for: crates/polars-plan/src/plans/conversion/stack_opt.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,19 @@ impl ConversionOptimizer {
6161
}
6262
}
6363

64-
pub(super) fn coerce_types(
64+
/// Optimizes the expressions in the scratch space. This should be called after filling the
65+
/// scratch space with the expressions that you want to optimize.
66+
pub(super) fn optimize_exprs(
6567
&mut self,
6668
expr_arena: &mut Arena<AExpr>,
6769
ir_arena: &mut Arena<IR>,
68-
current_node: Node,
70+
current_ir_node: Node,
6971
) -> PolarsResult<()> {
7072
// Different from the stack-opt in the optimizer phase, this does a single pass until fixed point per expression.
7173

7274
if let Some(rule) = &mut self.check {
73-
while let Some(x) = rule.optimize_plan(ir_arena, expr_arena, current_node)? {
74-
ir_arena.replace(current_node, x);
75+
while let Some(x) = rule.optimize_plan(ir_arena, expr_arena, current_ir_node)? {
76+
ir_arena.replace(current_ir_node, x);
7577
}
7678
}
7779

@@ -85,14 +87,14 @@ impl ConversionOptimizer {
8587

8688
if let Some(rule) = &mut self.simplify {
8789
while let Some(x) =
88-
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)?
90+
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_ir_node)?
8991
{
9092
expr_arena.replace(current_expr_node, x);
9193
}
9294
}
9395
if let Some(rule) = &mut self.coerce {
9496
while let Some(x) =
95-
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_node)?
97+
rule.optimize_expr(expr_arena, current_expr_node, ir_arena, current_ir_node)?
9698
{
9799
expr_arena.replace(current_expr_node, x);
98100
}

Diff for: crates/polars-plan/src/plans/conversion/type_coercion/mod.rs

+41-35
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,7 @@ impl OptimizationRule for TypeCoercionRule {
103103
} => {
104104
let input = expr_arena.get(expr);
105105

106-
inline_or_prune_cast(
107-
input,
108-
dtype,
109-
options.strict(),
110-
lp_node,
111-
lp_arena,
112-
expr_arena,
113-
)?
106+
inline_or_prune_cast(input, dtype, options, lp_node, lp_arena, expr_arena)?
114107
},
115108
AExpr::Ternary {
116109
truthy: truthy_node,
@@ -323,7 +316,13 @@ impl OptimizationRule for TypeCoercionRule {
323316
DataType::Categorical(_, _) if dtype.is_string() => {
324317
// pass
325318
},
326-
_ => cast_expr_ir(e, &dtype, &super_type, expr_arena, false)?,
319+
_ => cast_expr_ir(
320+
e,
321+
&dtype,
322+
&super_type,
323+
expr_arena,
324+
CastOptions::NonStrict,
325+
)?,
327326
}
328327
}
329328
}
@@ -355,50 +354,57 @@ impl OptimizationRule for TypeCoercionRule {
355354
fn inline_or_prune_cast(
356355
aexpr: &AExpr,
357356
dtype: &DataType,
358-
strict: bool,
357+
options: CastOptions,
359358
lp_node: Node,
360359
lp_arena: &Arena<IR>,
361360
expr_arena: &Arena<AExpr>,
362361
) -> PolarsResult<Option<AExpr>> {
363362
if !dtype.is_known() {
364363
return Ok(None);
365364
}
366-
let lv = match aexpr {
365+
366+
let out = match aexpr {
367367
// PRUNE
368-
AExpr::BinaryExpr {
369-
op: Operator::LogicalOr | Operator::LogicalAnd,
370-
..
371-
} => {
372-
if let Some(schema) = lp_arena.get(lp_node).input_schema(lp_arena) {
373-
let field = aexpr.to_field(&schema, Context::Default, expr_arena)?;
374-
if field.dtype == *dtype {
375-
return Ok(Some(aexpr.clone()));
376-
}
368+
AExpr::BinaryExpr { op, .. } => {
369+
use Operator::*;
370+
371+
match op {
372+
LogicalOr | LogicalAnd => {
373+
if let Some(schema) = lp_arena.get(lp_node).input_schema(lp_arena) {
374+
let field = aexpr.to_field(&schema, Context::Default, expr_arena)?;
375+
if field.dtype == *dtype {
376+
return Ok(Some(aexpr.clone()));
377+
}
378+
}
379+
380+
None
381+
},
382+
Eq | EqValidity | NotEq | NotEqValidity | Lt | LtEq | Gt | GtEq => {
383+
if dtype.is_bool() {
384+
Some(aexpr.clone())
385+
} else {
386+
None
387+
}
388+
},
389+
_ => None,
377390
}
378-
return Ok(None);
379391
},
380392
// INLINE
381-
AExpr::Literal(lv) => match try_inline_literal_cast(lv, dtype, strict)? {
382-
None => return Ok(None),
383-
Some(lv) => lv,
384-
},
385-
_ => return Ok(None),
393+
AExpr::Literal(lv) => try_inline_literal_cast(lv, dtype, options)?.map(AExpr::Literal),
394+
_ => None,
386395
};
387-
Ok(Some(AExpr::Literal(lv)))
396+
397+
Ok(out)
388398
}
389399

390400
fn try_inline_literal_cast(
391401
lv: &LiteralValue,
392402
dtype: &DataType,
393-
strict: bool,
403+
options: CastOptions,
394404
) -> PolarsResult<Option<LiteralValue>> {
395405
let lv = match lv {
396406
LiteralValue::Series(s) => {
397-
let s = if strict {
398-
s.strict_cast(dtype)
399-
} else {
400-
s.cast(dtype)
401-
}?;
407+
let s = s.cast_with_options(dtype, options)?;
402408
LiteralValue::Series(SpecialEq::new(s))
403409
},
404410
LiteralValue::StrCat(s) => {
@@ -477,7 +483,7 @@ fn cast_expr_ir(
477483
from_dtype: &DataType,
478484
to_dtype: &DataType,
479485
expr_arena: &mut Arena<AExpr>,
480-
strict: bool,
486+
options: CastOptions,
481487
) -> PolarsResult<()> {
482488
if from_dtype == to_dtype {
483489
return Ok(());
@@ -486,7 +492,7 @@ fn cast_expr_ir(
486492
check_cast(from_dtype, to_dtype)?;
487493

488494
if let AExpr::Literal(lv) = expr_arena.get(e.node()) {
489-
if let Some(literal) = try_inline_literal_cast(lv, to_dtype, strict)? {
495+
if let Some(literal) = try_inline_literal_cast(lv, to_dtype, options)? {
490496
e.set_node(expr_arena.add(AExpr::Literal(literal)));
491497
e.set_dtype(to_dtype.clone());
492498
return Ok(());

Diff for: py-polars/tests/unit/lazyframe/test_lazyframe.py

+26
Original file line numberDiff line numberDiff line change
@@ -1462,3 +1462,29 @@ def test_lf_unnest() -> None:
14621462
]
14631463
)
14641464
assert_frame_equal(lf.unnest("a", "b").collect(), expected)
1465+
1466+
1467+
def test_type_coercion_cast_boolean_after_comparison() -> None:
1468+
import operator
1469+
1470+
lf = pl.LazyFrame({"a": 1, "b": 2})
1471+
1472+
for op in [
1473+
operator.eq,
1474+
operator.ne,
1475+
operator.lt,
1476+
operator.le,
1477+
operator.gt,
1478+
operator.ge,
1479+
pl.Expr.eq_missing,
1480+
pl.Expr.ne_missing,
1481+
]:
1482+
e = op(pl.col("a"), pl.col("b")).cast(pl.Boolean).alias("o")
1483+
assert "cast" not in lf.with_columns(e).explain()
1484+
1485+
e = op(pl.col("a"), pl.col("b")).cast(pl.Boolean).cast(pl.Boolean).alias("o")
1486+
assert "cast" not in lf.with_columns(e).explain()
1487+
1488+
for op in [operator.and_, operator.or_, operator.xor]:
1489+
e = op(pl.col("a"), pl.col("b")).cast(pl.Boolean)
1490+
assert "cast" in lf.with_columns(e).explain()

Diff for: py-polars/tests/unit/operations/test_cast.py

+16
Original file line numberDiff line numberDiff line change
@@ -698,3 +698,19 @@ def test_cast_python_dtypes() -> None:
698698
assert s.cast(float).dtype == pl.Float64
699699
assert s.cast(bool).dtype == pl.Boolean
700700
assert s.cast(str).dtype == pl.String
701+
702+
703+
def test_overflowing_cast_literals_21023() -> None:
704+
for no_optimization in [True, False]:
705+
assert_frame_equal(
706+
(
707+
pl.LazyFrame()
708+
.select(
709+
pl.lit(pl.Series([128], dtype=pl.Int64)).cast(
710+
pl.Int8, wrap_numerical=True
711+
)
712+
)
713+
.collect(no_optimization=no_optimization)
714+
),
715+
pl.Series([-128], dtype=pl.Int8).to_frame(),
716+
)

Diff for: py-polars/tests/unit/operations/test_join.py

+40
Original file line numberDiff line numberDiff line change
@@ -1440,3 +1440,43 @@ def test_no_collapse_join_when_maintain_order_20725() -> None:
14401440
df_pl_eager = ldf.collect().filter(pl.col("Fraction_1") == 100)
14411441

14421442
assert_frame_equal(df_pl_lazy, df_pl_eager)
1443+
1444+
1445+
def test_join_where_predicate_type_coercion_21009() -> None:
1446+
left_frame = pl.LazyFrame(
1447+
{
1448+
"left_match": ["A", "B", "C", "D", "E", "F"],
1449+
"left_date_start": range(6),
1450+
}
1451+
)
1452+
1453+
right_frame = pl.LazyFrame(
1454+
{
1455+
"right_match": ["D", "E", "F", "G", "H", "I"],
1456+
"right_date": range(6),
1457+
}
1458+
)
1459+
1460+
# Note: Cannot eq the plans as the operand sides are non-deterministic
1461+
1462+
q1 = left_frame.join_where(
1463+
right_frame,
1464+
pl.col("left_match") == pl.col("right_match"),
1465+
pl.col("right_date") >= pl.col("left_date_start"),
1466+
)
1467+
1468+
plan = q1.explain().splitlines()
1469+
assert plan[0].strip().startswith("FILTER")
1470+
assert plan[1].strip().startswith("INNER JOIN")
1471+
1472+
q2 = left_frame.join_where(
1473+
right_frame,
1474+
pl.all_horizontal(pl.col("left_match") == pl.col("right_match")),
1475+
pl.col("right_date") >= pl.col("left_date_start"),
1476+
)
1477+
1478+
plan = q2.explain().splitlines()
1479+
assert plan[0].strip().startswith("FILTER")
1480+
assert plan[1].strip().startswith("INNER JOIN")
1481+
1482+
assert_frame_equal(q1.collect(), q2.collect())

Diff for: py-polars/tests/unit/test_cwc.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,7 @@ def test_cwc_with_internal_aliases() -> None:
146146
explain = df.explain()
147147

148148
assert (
149-
"""[[(col("a")) == (2)].cast(Boolean).alias("c"), [(col("b")) * (3)].alias("d")]"""
150-
in explain
149+
"""[[(col("a")) == (2)].alias("c"), [(col("b")) * (3)].alias("d")]""" in explain
151150
)
152151

153152

0 commit comments

Comments
 (0)