Skip to content

Commit 7364fa8

Browse files
committed
vortex-array: stats rewrite predicate satisfiers
Add satisfier rewrites for binary comparisons, boolean composition, and between expressions. The rewrites prove predicates true from min/max stats and add null/NaN guards where comparisons require every row to evaluate true. Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 0a41704 commit 7364fa8

1 file changed

Lines changed: 184 additions & 0 deletions

File tree

vortex-array/src/stats/rewrite/builtins.rs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,63 @@ impl StatsRewriteRule for BinaryStatsRewrite {
130130
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
131131
})
132132
}
133+
134+
fn satisfy(
135+
&self,
136+
expr: &Expression,
137+
ctx: &StatsRewriteCtx<'_>,
138+
) -> VortexResult<Option<Expression>> {
139+
let operator = expr.as_::<Binary>();
140+
let lhs = expr.child(0);
141+
let rhs = expr.child(1);
142+
143+
// Min/max stats may be truncated to outward bounds (stored min ≤
144+
// true min, stored max ≥ true max), which keeps every comparison
145+
// below conservative: a bound that proves the predicate proves it
146+
// for the true extrema too.
147+
let value_predicate = match operator {
148+
// Both value ranges pinch to the same single point.
149+
Operator::Eq => min(lhs)
150+
.zip(max(lhs))
151+
.zip(min(rhs).zip(max(rhs)))
152+
.map(|((min_lhs, max_lhs), (min_rhs, max_rhs))| {
153+
and(lt_eq(max_lhs, min_rhs), gt_eq(min_lhs, max_rhs))
154+
}),
155+
// The value ranges are disjoint.
156+
Operator::NotEq => max(lhs)
157+
.zip(min(rhs))
158+
.zip(min(lhs).zip(max(rhs)))
159+
.map(|((max_lhs, min_rhs), (min_lhs, max_rhs))| {
160+
or(lt(max_lhs, min_rhs), gt(min_lhs, max_rhs))
161+
}),
162+
Operator::Gt => min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)),
163+
Operator::Gte => min(lhs).zip(max(rhs)).map(|(a, b)| gt_eq(a, b)),
164+
Operator::Lt => max(lhs).zip(min(rhs)).map(|(a, b)| lt(a, b)),
165+
Operator::Lte => max(lhs).zip(min(rhs)).map(|(a, b)| lt_eq(a, b)),
166+
Operator::And => {
167+
return Ok(match (ctx.satisfy(lhs)?, ctx.satisfy(rhs)?) {
168+
(Some(lhs), Some(rhs)) => Some(and(lhs, rhs)),
169+
_ => None,
170+
});
171+
}
172+
Operator::Or => {
173+
let lhs_satisfier = ctx.satisfy(lhs)?;
174+
let rhs_satisfier = ctx.satisfy(rhs)?;
175+
return Ok(or_collect(lhs_satisfier.into_iter().chain(rhs_satisfier)));
176+
}
177+
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => return Ok(None),
178+
};
179+
value_predicate
180+
.map(|value_predicate| {
181+
// Satisfaction must prove more than the values: a NaN
182+
// operand makes every comparison false, and a null operand
183+
// makes it null — neither is `true`, so the rewrite is only
184+
// sound over rows proven non-NaN and non-null.
185+
let guarded = with_nan_predicate(ctx, lhs, rhs, value_predicate)?;
186+
with_all_non_null_predicate(ctx, [lhs, rhs], guarded)
187+
})
188+
.transpose()
189+
}
133190
}
134191

135192
#[derive(Debug)]
@@ -154,6 +211,21 @@ impl StatsRewriteRule for BetweenStatsRewrite {
154211
let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
155212
ctx.falsify(&and(lhs, rhs))
156213
}
214+
215+
fn satisfy(
216+
&self,
217+
expr: &Expression,
218+
ctx: &StatsRewriteCtx<'_>,
219+
) -> VortexResult<Option<Expression>> {
220+
let options = expr.as_::<Between>();
221+
let arr = expr.child(0).clone();
222+
let lower = expr.child(1).clone();
223+
let upper = expr.child(2).clone();
224+
225+
let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
226+
let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
227+
ctx.satisfy(&and(lhs, rhs))
228+
}
157229
}
158230

159231
#[derive(Debug)]
@@ -503,6 +575,27 @@ fn with_nan_predicate(
503575
with_all_non_nan_predicate(ctx, [lhs, rhs], value_predicate)
504576
}
505577

578+
// Satisfaction rewrites prove a predicate true for *every* row, but min/max
579+
// stats describe non-null values only: a null operand row evaluates the
580+
// comparison to null, not true. Guard each nullable operand with an
581+
// all-non-null check; non-nullable operands need none.
582+
fn with_all_non_null_predicate<'a>(
583+
ctx: &StatsRewriteCtx<'_>,
584+
exprs: impl IntoIterator<Item = &'a Expression>,
585+
value_predicate: Expression,
586+
) -> VortexResult<Expression> {
587+
let mut null_checks = Vec::new();
588+
for expr in exprs {
589+
if ctx.return_dtype(expr)?.is_nullable() {
590+
null_checks.push(all_non_null(expr));
591+
}
592+
}
593+
Ok(match and_collect(null_checks) {
594+
Some(null_check) => and(null_check, value_predicate),
595+
None => value_predicate,
596+
})
597+
}
598+
506599
fn with_all_non_nan_predicate<'a>(
507600
ctx: &StatsRewriteCtx<'_>,
508601
exprs: impl IntoIterator<Item = &'a Expression>,
@@ -626,6 +719,7 @@ mod tests {
626719
("f", DType::Primitive(PType::F32, Nullability::NonNullable)),
627720
("s", DType::Utf8(Nullability::NonNullable)),
628721
("t", DType::Utf8(Nullability::NonNullable)),
722+
("n", DType::Primitive(PType::I32, Nullability::Nullable)),
629723
]),
630724
Nullability::NonNullable,
631725
)
@@ -671,6 +765,96 @@ mod tests {
671765
Ok(())
672766
}
673767

768+
#[test]
fn rewrites_comparison_satisfier() -> VortexResult<()> {
    // A non-nullable integer column needs no guard: the bound check is the
    // entire satisfier.
    let upper_bounded = lt(col("a"), lit(10));
    let expected = lt(stat(col("a"), Stat::Max), lit(10));
    assert_eq!(satisfy(&upper_bounded)?, Some(expected));

    let lower_bounded = gt_eq(col("a"), lit(10));
    let expected = gt_eq(stat(col("a"), Stat::Min), lit(10));
    assert_eq!(satisfy(&lower_bounded)?, Some(expected));

    // Comparing two columns draws a bound from each side.
    let column_pair = lt(col("a"), col("b"));
    let expected = lt(stat(col("a"), Stat::Max), stat(col("b"), Stat::Min));
    assert_eq!(satisfy(&column_pair)?, Some(expected));

    // Float columns additionally need a NaN-free proof, because a NaN row
    // fails every comparison.
    let float_lower_bounded = gt(col("f"), lit(1.0f32));
    let expected = and(
        nan_free(col("f")),
        gt(stat(col("f"), Stat::Min), lit(1.0f32)),
    );
    assert_eq!(satisfy(&float_lower_bounded)?, Some(expected));

    // Nullable columns additionally need an all-non-null proof, because a
    // null row makes the comparison null instead of true.
    let nullable_upper_bounded = lt(col("n"), lit(10));
    let expected = and(
        all_non_null(&col("n")),
        lt(stat(col("n"), Stat::Max), lit(10)),
    );
    assert_eq!(satisfy(&nullable_upper_bounded)?, Some(expected));

    Ok(())
}
813+
814+
#[test]
fn rewrites_boolean_satisfiers() -> VortexResult<()> {
    // AND: the satisfier is the conjunction of both child satisfiers.
    let conjunction = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));
    let expected = and(
        gt(stat(col("a"), Stat::Min), lit(10)),
        lt(stat(col("a"), Stat::Max), lit(50)),
    );
    assert_eq!(satisfy(&conjunction)?, Some(expected));

    // OR: the satisfier is the disjunction of the child satisfiers.
    let disjunction = or(gt(col("a"), lit(10)), lt(col("a"), lit(0)));
    let expected = or(
        gt(stat(col("a"), Stat::Min), lit(10)),
        lt(stat(col("a"), Stat::Max), lit(0)),
    );
    assert_eq!(satisfy(&disjunction)?, Some(expected));

    Ok(())
}
836+
837+
#[test]
fn rewrites_between_satisfier() -> VortexResult<()> {
    // An inclusive range on both ends: 10 <= a <= 50.
    let inclusive = BetweenOptions {
        lower_strict: StrictComparison::NonStrict,
        upper_strict: StrictComparison::NonStrict,
    };
    let in_range = between(col("a"), lit(10), lit(50), inclusive);

    // The lower bound is checked against the column minimum and the upper
    // bound against the column maximum.
    let expected = and(
        lt_eq(lit(10), stat(col("a"), Stat::Min)),
        lt_eq(stat(col("a"), Stat::Max), lit(50)),
    );
    assert_eq!(satisfy(&in_range)?, Some(expected));

    Ok(())
}
857+
674858
#[test]
675859
fn rewrites_boolean_falsifiers() -> VortexResult<()> {
676860
let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));

0 commit comments

Comments
 (0)