Skip to content

Commit 0dd63f0

Browse files
authored
Fix generation of stat pruning expression for unsupported dtypes (#8326)
Flagged by the fuzzer and likely doesn't happen in practice but better to be defensive here in case we add complex partials this is alternative version of #8302
1 parent 82b24cd commit 0dd63f0

1 file changed

Lines changed: 80 additions & 47 deletions

File tree

vortex-array/src/stats/rewrite/builtins.rs

Lines changed: 80 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ impl StatsRewriteRule for BinaryStatsRewrite {
8484

8585
Ok(match operator {
8686
Operator::Eq => {
87-
let left = min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b));
88-
let right = min(rhs).zip(max(lhs)).map(|(a, b)| gt(a, b));
87+
let left = min(lhs, ctx).zip(max(rhs, ctx)).map(|(a, b)| gt(a, b));
88+
let right = min(rhs, ctx).zip(max(lhs, ctx)).map(|(a, b)| gt(a, b));
8989
or_collect(left.into_iter().chain(right))
9090
.map(|value_predicate| with_nan_predicate(ctx, lhs, rhs, value_predicate))
9191
.transpose()?
9292
}
93-
Operator::NotEq => min(lhs)
94-
.zip(max(rhs))
95-
.zip(max(lhs).zip(min(rhs)))
93+
Operator::NotEq => min(lhs, ctx)
94+
.zip(max(rhs, ctx))
95+
.zip(max(lhs, ctx).zip(min(rhs, ctx)))
9696
.map(|((min_lhs, max_rhs), (max_lhs, min_rhs))| {
9797
with_nan_predicate(
9898
ctx,
@@ -102,20 +102,20 @@ impl StatsRewriteRule for BinaryStatsRewrite {
102102
)
103103
})
104104
.transpose()?,
105-
Operator::Gt => max(lhs)
106-
.zip(min(rhs))
105+
Operator::Gt => max(lhs, ctx)
106+
.zip(min(rhs, ctx))
107107
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt_eq(a, b)))
108108
.transpose()?,
109-
Operator::Gte => max(lhs)
110-
.zip(min(rhs))
109+
Operator::Gte => max(lhs, ctx)
110+
.zip(min(rhs, ctx))
111111
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, lt(a, b)))
112112
.transpose()?,
113-
Operator::Lt => min(lhs)
114-
.zip(max(rhs))
113+
Operator::Lt => min(lhs, ctx)
114+
.zip(max(rhs, ctx))
115115
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt_eq(a, b)))
116116
.transpose()?,
117-
Operator::Lte => min(lhs)
118-
.zip(max(rhs))
117+
Operator::Lte => min(lhs, ctx)
118+
.zip(max(rhs, ctx))
119119
.map(|(a, b)| with_nan_predicate(ctx, lhs, rhs, gt(a, b)))
120120
.transpose()?,
121121
Operator::And => {
@@ -167,17 +167,17 @@ impl StatsRewriteRule for IsNullLegacyStatsRewrite {
167167
fn falsify(
168168
&self,
169169
expr: &Expression,
170-
_ctx: &StatsRewriteCtx<'_>,
170+
ctx: &StatsRewriteCtx<'_>,
171171
) -> VortexResult<Option<Expression>> {
172-
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
172+
Ok(null_count(expr.child(0), ctx).map(|null_count| eq(null_count, lit(0u64))))
173173
}
174174

175175
fn satisfy(
176176
&self,
177177
expr: &Expression,
178-
_ctx: &StatsRewriteCtx<'_>,
178+
ctx: &StatsRewriteCtx<'_>,
179179
) -> VortexResult<Option<Expression>> {
180-
Ok(null_count(expr.child(0))
180+
Ok(null_count(expr.child(0), ctx)
181181
.map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, []))))
182182
}
183183
}
@@ -227,18 +227,18 @@ impl StatsRewriteRule for IsNotNullLegacyStatsRewrite {
227227
fn falsify(
228228
&self,
229229
expr: &Expression,
230-
_ctx: &StatsRewriteCtx<'_>,
230+
ctx: &StatsRewriteCtx<'_>,
231231
) -> VortexResult<Option<Expression>> {
232-
Ok(null_count(expr.child(0))
232+
Ok(null_count(expr.child(0), ctx)
233233
.map(|null_count| eq(null_count, RowCount.new_expr(EmptyOptions, []))))
234234
}
235235

236236
fn satisfy(
237237
&self,
238238
expr: &Expression,
239-
_ctx: &StatsRewriteCtx<'_>,
239+
ctx: &StatsRewriteCtx<'_>,
240240
) -> VortexResult<Option<Expression>> {
241-
Ok(null_count(expr.child(0)).map(|null_count| eq(null_count, lit(0u64))))
241+
Ok(null_count(expr.child(0), ctx).map(|null_count| eq(null_count, lit(0u64))))
242242
}
243243
}
244244

@@ -287,7 +287,7 @@ impl StatsRewriteRule for LikeStatsRewrite {
287287
fn falsify(
288288
&self,
289289
expr: &Expression,
290-
_ctx: &StatsRewriteCtx<'_>,
290+
ctx: &StatsRewriteCtx<'_>,
291291
) -> VortexResult<Option<Expression>> {
292292
let like_options = expr.as_::<Like>();
293293
if like_options.negated || like_options.case_insensitive {
@@ -304,8 +304,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
304304
let source = expr.child(0);
305305
Ok(match LikeVariant::from_str(pattern) {
306306
Some(LikeVariant::Exact(text)) => {
307-
min(source)
308-
.zip(max(source))
307+
min(source, ctx)
308+
.zip(max(source, ctx))
309309
.map(|(source_min, source_max)| {
310310
or(
311311
gt(source_min, lit(text.as_ref())),
@@ -317,8 +317,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
317317
let Some(successor) = prefix.to_string().increment().ok() else {
318318
return Ok(None);
319319
};
320-
min(source)
321-
.zip(max(source))
320+
min(source, ctx)
321+
.zip(max(source, ctx))
322322
.map(|(source_min, source_max)| {
323323
or(
324324
gt_eq(source_min, lit(successor)),
@@ -361,10 +361,10 @@ impl StatsRewriteRule for ListContainsStatsRewrite {
361361
return Ok(Some(lit(true)));
362362
}
363363

364-
let Some(value_max) = max(needle) else {
364+
let Some(value_max) = max(needle, ctx) else {
365365
return Ok(None);
366366
};
367-
let Some(value_min) = min(needle) else {
367+
let Some(value_min) = min(needle, ctx) else {
368368
return Ok(None);
369369
};
370370

@@ -398,10 +398,10 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {
398398

399399
let Some((operator, lhs_stat)) = (match dynamic.operator {
400400
CompareOperator::Eq | CompareOperator::NotEq => None,
401-
CompareOperator::Gt => max(lhs).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)),
402-
CompareOperator::Gte => max(lhs).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)),
403-
CompareOperator::Lt => min(lhs).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)),
404-
CompareOperator::Lte => min(lhs).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)),
401+
CompareOperator::Gt => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lte, lhs_stat)),
402+
CompareOperator::Gte => max(lhs, ctx).map(|lhs_stat| (CompareOperator::Lt, lhs_stat)),
403+
CompareOperator::Lt => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gte, lhs_stat)),
404+
CompareOperator::Lte => min(lhs, ctx).map(|lhs_stat| (CompareOperator::Gt, lhs_stat)),
405405
}) else {
406406
return Ok(None);
407407
};
@@ -418,16 +418,16 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {
418418
}
419419
}
420420

421-
fn min(expr: &Expression) -> Option<Expression> {
422-
stat_expr(expr, Stat::Min)
421+
fn min(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
422+
stat_expr(expr, Stat::Min, ctx)
423423
}
424424

425-
fn max(expr: &Expression) -> Option<Expression> {
426-
stat_expr(expr, Stat::Max)
425+
fn max(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
426+
stat_expr(expr, Stat::Max, ctx)
427427
}
428428

429-
fn null_count(expr: &Expression) -> Option<Expression> {
430-
stat_expr(expr, Stat::NullCount)
429+
fn null_count(expr: &Expression, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
430+
stat_expr(expr, Stat::NullCount, ctx)
431431
}
432432

433433
fn all_null(expr: &Expression) -> Expression {
@@ -474,7 +474,7 @@ fn has_nans(dtype: &DType) -> bool {
474474
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
475475
}
476476

477-
fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
477+
fn stat_expr(expr: &Expression, stat: Stat, ctx: &StatsRewriteCtx<'_>) -> Option<Expression> {
478478
if let Some(literal) = literal_stat(expr, stat) {
479479
return Some(literal);
480480
}
@@ -487,11 +487,18 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
487487
}
488488

489489
if let Some(dtype) = expr.as_opt::<Cast>() {
490-
return cast_stat(expr.child(0), dtype, stat);
491-
}
492-
493-
stat.aggregate_fn()
494-
.map(|aggregate_fn| stat_fn(expr.clone(), aggregate_fn))
490+
return cast_stat(expr.child(0), dtype, stat, ctx);
491+
}
492+
493+
let aggregate_fn = stat.aggregate_fn()?;
494+
// The aggregate may not support the expression's dtype, e.g. min/max over structs,
495+
// even when the predicate itself is well-typed. Such stats cannot be lowered later,
496+
// so do not reference them in the rewrite.
497+
let input_dtype = ctx.return_dtype(expr).ok()?;
498+
aggregate_fn
499+
.return_dtype(&input_dtype)
500+
.is_some()
501+
.then(|| stat_fn(expr.clone(), aggregate_fn))
495502
}
496503

497504
fn with_nan_predicate(
@@ -545,10 +552,15 @@ fn literal_stat(expr: &Expression, stat: Stat) -> Option<Expression> {
545552
}
546553
}
547554

548-
fn cast_stat(expr: &Expression, dtype: &DType, stat: Stat) -> Option<Expression> {
555+
fn cast_stat(
556+
expr: &Expression,
557+
dtype: &DType,
558+
stat: Stat,
559+
ctx: &StatsRewriteCtx<'_>,
560+
) -> Option<Expression> {
549561
match stat {
550-
Stat::Min | Stat::Max => stat_expr(expr, stat).map(|stat| cast(stat, dtype.clone())),
551-
Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat),
562+
Stat::Min | Stat::Max => stat_expr(expr, stat, ctx).map(|stat| cast(stat, dtype.clone())),
563+
Stat::NaNCount | Stat::Sum | Stat::UncompressedSizeInBytes => stat_expr(expr, stat, ctx),
552564
Stat::NullCount | Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted => None,
553565
}
554566
}
@@ -626,11 +638,19 @@ mod tests {
626638
("f", DType::Primitive(PType::F32, Nullability::NonNullable)),
627639
("s", DType::Utf8(Nullability::NonNullable)),
628640
("t", DType::Utf8(Nullability::NonNullable)),
641+
("n", nested_struct_dtype()),
629642
]),
630643
Nullability::NonNullable,
631644
)
632645
}
633646

647+
fn nested_struct_dtype() -> DType {
648+
DType::Struct(
649+
StructFields::from_iter([("x", DType::Primitive(PType::F32, Nullability::Nullable))]),
650+
Nullability::NonNullable,
651+
)
652+
}
653+
634654
fn falsify(expr: &Expression) -> VortexResult<Option<Expression>> {
635655
expr.falsify(&test_scope(), &SESSION)
636656
}
@@ -848,6 +868,19 @@ mod tests {
848868
Ok(())
849869
}
850870

871+
#[test]
872+
fn skips_falsifier_when_min_max_unsupported_for_dtype() -> VortexResult<()> {
873+
// Struct comparisons are valid predicates, but min/max aggregates do not
874+
// support struct inputs, so no stats-backed falsifier should be produced.
875+
let struct_scalar = Scalar::struct_(
876+
nested_struct_dtype(),
877+
vec![Scalar::primitive(1.0f32, Nullability::Nullable)],
878+
);
879+
assert_eq!(falsify(&lt_eq(col("n"), lit(struct_scalar.clone())))?, None);
880+
assert_eq!(falsify(&eq(col("n"), lit(struct_scalar)))?, None);
881+
Ok(())
882+
}
883+
851884
#[test]
852885
fn forwards_min_max_through_safe_cast() -> VortexResult<()> {
853886
let dtype = DType::Primitive(PType::I64, Nullability::NonNullable);

0 commit comments

Comments
 (0)