Skip to content

Commit 84bfe64

Browse files
committed
vortex-array: stats rewrite predicate satisfiers
Add satisfier rewrites for binary comparisons, boolean composition, and between expressions. The rewrites prove predicates true from min/max stats and add null/NaN guards where comparisons require every row to evaluate true. Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 0a41704 commit 84bfe64

1 file changed

Lines changed: 182 additions & 0 deletions

File tree

vortex-array/src/stats/rewrite/builtins.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,61 @@ impl StatsRewriteRule for BinaryStatsRewrite {
130130
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
131131
})
132132
}
133+
134+
fn satisfy(
135+
&self,
136+
expr: &Expression,
137+
ctx: &StatsRewriteCtx<'_>,
138+
) -> VortexResult<Option<Expression>> {
139+
let operator = expr.as_::<Binary>();
140+
let lhs = expr.child(0);
141+
let rhs = expr.child(1);
142+
143+
// Min/max stats may be truncated to outward bounds (stored min ≤
144+
// true min, stored max ≥ true max), which keeps every comparison
145+
// below conservative: a bound that proves the predicate proves it
146+
// for the true extrema too.
147+
let value_predicate = match operator {
148+
// Both value ranges pinch to the same single point.
149+
Operator::Eq => min(lhs).zip(max(lhs)).zip(min(rhs).zip(max(rhs))).map(
150+
|((min_lhs, max_lhs), (min_rhs, max_rhs))| {
151+
and(lt_eq(max_lhs, min_rhs), gt_eq(min_lhs, max_rhs))
152+
},
153+
),
154+
// The value ranges are disjoint.
155+
Operator::NotEq => max(lhs).zip(min(rhs)).zip(min(lhs).zip(max(rhs))).map(
156+
|((max_lhs, min_rhs), (min_lhs, max_rhs))| {
157+
or(lt(max_lhs, min_rhs), gt(min_lhs, max_rhs))
158+
},
159+
),
160+
Operator::Gt => min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)),
161+
Operator::Gte => min(lhs).zip(max(rhs)).map(|(a, b)| gt_eq(a, b)),
162+
Operator::Lt => max(lhs).zip(min(rhs)).map(|(a, b)| lt(a, b)),
163+
Operator::Lte => max(lhs).zip(min(rhs)).map(|(a, b)| lt_eq(a, b)),
164+
Operator::And => {
165+
return Ok(match (ctx.satisfy(lhs)?, ctx.satisfy(rhs)?) {
166+
(Some(lhs), Some(rhs)) => Some(and(lhs, rhs)),
167+
_ => None,
168+
});
169+
}
170+
Operator::Or => {
171+
let lhs_satisfier = ctx.satisfy(lhs)?;
172+
let rhs_satisfier = ctx.satisfy(rhs)?;
173+
return Ok(or_collect(lhs_satisfier.into_iter().chain(rhs_satisfier)));
174+
}
175+
Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => return Ok(None),
176+
};
177+
value_predicate
178+
.map(|value_predicate| {
179+
// Satisfaction must prove more than the values: a NaN
180+
// operand makes every comparison false, and a null operand
181+
// makes it null — neither is `true`, so the rewrite is only
182+
// sound over rows proven non-NaN and non-null.
183+
let guarded = with_nan_predicate(ctx, lhs, rhs, value_predicate)?;
184+
with_all_non_null_predicate(ctx, [lhs, rhs], guarded)
185+
})
186+
.transpose()
187+
}
133188
}
134189

135190
#[derive(Debug)]
@@ -154,6 +209,21 @@ impl StatsRewriteRule for BetweenStatsRewrite {
154209
let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
155210
ctx.falsify(&and(lhs, rhs))
156211
}
212+
213+
fn satisfy(
214+
&self,
215+
expr: &Expression,
216+
ctx: &StatsRewriteCtx<'_>,
217+
) -> VortexResult<Option<Expression>> {
218+
let options = expr.as_::<Between>();
219+
let arr = expr.child(0).clone();
220+
let lower = expr.child(1).clone();
221+
let upper = expr.child(2).clone();
222+
223+
let lhs = Binary.new_expr(options.lower_strict.to_operator(), [lower, arr.clone()]);
224+
let rhs = Binary.new_expr(options.upper_strict.to_operator(), [arr, upper]);
225+
ctx.satisfy(&and(lhs, rhs))
226+
}
157227
}
158228

159229
#[derive(Debug)]
@@ -503,6 +573,27 @@ fn with_nan_predicate(
503573
with_all_non_nan_predicate(ctx, [lhs, rhs], value_predicate)
504574
}
505575

576+
// Satisfaction rewrites prove a predicate true for *every* row, but min/max
577+
// stats describe non-null values only: a null operand row evaluates the
578+
// comparison to null, not true. Guard each nullable operand with an
579+
// all-non-null check; non-nullable operands need none.
580+
fn with_all_non_null_predicate<'a>(
581+
ctx: &StatsRewriteCtx<'_>,
582+
exprs: impl IntoIterator<Item = &'a Expression>,
583+
value_predicate: Expression,
584+
) -> VortexResult<Expression> {
585+
let mut null_checks = Vec::new();
586+
for expr in exprs {
587+
if ctx.return_dtype(expr)?.is_nullable() {
588+
null_checks.push(all_non_null(expr));
589+
}
590+
}
591+
Ok(match and_collect(null_checks) {
592+
Some(null_check) => and(null_check, value_predicate),
593+
None => value_predicate,
594+
})
595+
}
596+
506597
fn with_all_non_nan_predicate<'a>(
507598
ctx: &StatsRewriteCtx<'_>,
508599
exprs: impl IntoIterator<Item = &'a Expression>,
@@ -626,6 +717,7 @@ mod tests {
626717
("f", DType::Primitive(PType::F32, Nullability::NonNullable)),
627718
("s", DType::Utf8(Nullability::NonNullable)),
628719
("t", DType::Utf8(Nullability::NonNullable)),
720+
("n", DType::Primitive(PType::I32, Nullability::Nullable)),
629721
]),
630722
Nullability::NonNullable,
631723
)
@@ -671,6 +763,96 @@ mod tests {
671763
Ok(())
672764
}
673765

766+
#[test]
767+
fn rewrites_comparison_satisfier() -> VortexResult<()> {
768+
// Non-nullable integer: the value condition alone proves all-true.
769+
let expr = lt(col("a"), lit(10));
770+
assert_eq!(
771+
satisfy(&expr)?,
772+
Some(lt(stat(col("a"), Stat::Max), lit(10)))
773+
);
774+
775+
let expr = gt_eq(col("a"), lit(10));
776+
assert_eq!(
777+
satisfy(&expr)?,
778+
Some(gt_eq(stat(col("a"), Stat::Min), lit(10)))
779+
);
780+
781+
// Column-to-column comparison uses both sides' stats.
782+
let expr = lt(col("a"), col("b"));
783+
assert_eq!(
784+
satisfy(&expr)?,
785+
Some(lt(stat(col("a"), Stat::Max), stat(col("b"), Stat::Min)))
786+
);
787+
788+
// Floats must also prove no NaNs: a NaN row never satisfies a
789+
// comparison.
790+
let expr = gt(col("f"), lit(1.0f32));
791+
assert_eq!(
792+
satisfy(&expr)?,
793+
Some(and(
794+
nan_free(col("f")),
795+
gt(stat(col("f"), Stat::Min), lit(1.0f32))
796+
))
797+
);
798+
799+
// Nullable operands must also prove no nulls: a null row evaluates
800+
// the comparison to null, not true.
801+
let expr = lt(col("n"), lit(10));
802+
assert_eq!(
803+
satisfy(&expr)?,
804+
Some(and(
805+
all_non_null(&col("n")),
806+
lt(stat(col("n"), Stat::Max), lit(10))
807+
))
808+
);
809+
Ok(())
810+
}
811+
812+
#[test]
813+
fn rewrites_boolean_satisfiers() -> VortexResult<()> {
814+
// Conjunctions require both satisfiers; disjunctions accept either.
815+
let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));
816+
assert_eq!(
817+
satisfy(&expr)?,
818+
Some(and(
819+
gt(stat(col("a"), Stat::Min), lit(10)),
820+
lt(stat(col("a"), Stat::Max), lit(50)),
821+
))
822+
);
823+
824+
let expr = or(gt(col("a"), lit(10)), lt(col("a"), lit(0)));
825+
assert_eq!(
826+
satisfy(&expr)?,
827+
Some(or(
828+
gt(stat(col("a"), Stat::Min), lit(10)),
829+
lt(stat(col("a"), Stat::Max), lit(0)),
830+
))
831+
);
832+
Ok(())
833+
}
834+
835+
#[test]
836+
fn rewrites_between_satisfier() -> VortexResult<()> {
837+
let expr = between(
838+
col("a"),
839+
lit(10),
840+
lit(50),
841+
BetweenOptions {
842+
lower_strict: StrictComparison::NonStrict,
843+
upper_strict: StrictComparison::NonStrict,
844+
},
845+
);
846+
assert_eq!(
847+
satisfy(&expr)?,
848+
Some(and(
849+
lt_eq(lit(10), stat(col("a"), Stat::Min)),
850+
lt_eq(stat(col("a"), Stat::Max), lit(50)),
851+
))
852+
);
853+
Ok(())
854+
}
855+
674856
#[test]
675857
fn rewrites_boolean_falsifiers() -> VortexResult<()> {
676858
let expr = and(gt(col("a"), lit(10)), lt(col("a"), lit(50)));

0 commit comments

Comments
 (0)