@@ -84,15 +84,15 @@ impl StatsRewriteRule for BinaryStatsRewrite {
8484
8585 Ok ( match operator {
8686 Operator :: Eq => {
87- let left = min ( lhs) . zip ( max ( rhs) ) . map ( |( a, b) | gt ( a, b) ) ;
88- let right = min ( rhs) . zip ( max ( lhs) ) . map ( |( a, b) | gt ( a, b) ) ;
87+ let left = min ( lhs, ctx ) . zip ( max ( rhs, ctx ) ) . map ( |( a, b) | gt ( a, b) ) ;
88+ let right = min ( rhs, ctx ) . zip ( max ( lhs, ctx ) ) . map ( |( a, b) | gt ( a, b) ) ;
8989 or_collect ( left. into_iter ( ) . chain ( right) )
9090 . map ( |value_predicate| with_nan_predicate ( ctx, lhs, rhs, value_predicate) )
9191 . transpose ( ) ?
9292 }
93- Operator :: NotEq => min ( lhs)
94- . zip ( max ( rhs) )
95- . zip ( max ( lhs) . zip ( min ( rhs) ) )
93+ Operator :: NotEq => min ( lhs, ctx )
94+ . zip ( max ( rhs, ctx ) )
95+ . zip ( max ( lhs, ctx ) . zip ( min ( rhs, ctx ) ) )
9696 . map ( |( ( min_lhs, max_rhs) , ( max_lhs, min_rhs) ) | {
9797 with_nan_predicate (
9898 ctx,
@@ -102,20 +102,20 @@ impl StatsRewriteRule for BinaryStatsRewrite {
102102 )
103103 } )
104104 . transpose ( ) ?,
105- Operator :: Gt => max ( lhs)
106- . zip ( min ( rhs) )
105+ Operator :: Gt => max ( lhs, ctx )
106+ . zip ( min ( rhs, ctx ) )
107107 . map ( |( a, b) | with_nan_predicate ( ctx, lhs, rhs, lt_eq ( a, b) ) )
108108 . transpose ( ) ?,
109- Operator :: Gte => max ( lhs)
110- . zip ( min ( rhs) )
109+ Operator :: Gte => max ( lhs, ctx )
110+ . zip ( min ( rhs, ctx ) )
111111 . map ( |( a, b) | with_nan_predicate ( ctx, lhs, rhs, lt ( a, b) ) )
112112 . transpose ( ) ?,
113- Operator :: Lt => min ( lhs)
114- . zip ( max ( rhs) )
113+ Operator :: Lt => min ( lhs, ctx )
114+ . zip ( max ( rhs, ctx ) )
115115 . map ( |( a, b) | with_nan_predicate ( ctx, lhs, rhs, gt_eq ( a, b) ) )
116116 . transpose ( ) ?,
117- Operator :: Lte => min ( lhs)
118- . zip ( max ( rhs) )
117+ Operator :: Lte => min ( lhs, ctx )
118+ . zip ( max ( rhs, ctx ) )
119119 . map ( |( a, b) | with_nan_predicate ( ctx, lhs, rhs, gt ( a, b) ) )
120120 . transpose ( ) ?,
121121 Operator :: And => {
@@ -167,17 +167,17 @@ impl StatsRewriteRule for IsNullLegacyStatsRewrite {
// `falsify` for the is-null rewrite rule (diff view: `-` = before, `+` = after).
// Returns `Ok(Some(null_count == 0u64))` for the stat of `expr.child(0)`, or
// `Ok(None)` when `null_count` yields no stat expression.
// The post-image renames `_ctx` to `ctx` and forwards it to `null_count`.
167167 fn falsify (
168168 & self ,
169169 expr : & Expression ,
170- _ctx : & StatsRewriteCtx < ' _ > ,
170+ ctx : & StatsRewriteCtx < ' _ > ,
171171 ) -> VortexResult < Option < Expression > > {
172- Ok ( null_count ( expr. child ( 0 ) ) . map ( |null_count| eq ( null_count, lit ( 0u64 ) ) ) )
172+ Ok ( null_count ( expr. child ( 0 ) , ctx ) . map ( |null_count| eq ( null_count, lit ( 0u64 ) ) ) )
173173 }
174174
// `satisfy` for the is-null rewrite rule (diff view: `-` = before, `+` = after).
// Returns `Ok(Some(null_count == RowCount))` for the stat of `expr.child(0)`,
// where the right-hand side is `RowCount.new_expr(EmptyOptions, [])`, or
// `Ok(None)` when `null_count` yields no stat expression.
// The post-image renames `_ctx` to `ctx` and forwards it to `null_count`.
175175 fn satisfy (
176176 & self ,
177177 expr : & Expression ,
178- _ctx : & StatsRewriteCtx < ' _ > ,
178+ ctx : & StatsRewriteCtx < ' _ > ,
179179 ) -> VortexResult < Option < Expression > > {
180- Ok ( null_count ( expr. child ( 0 ) )
180+ Ok ( null_count ( expr. child ( 0 ) , ctx )
181181 . map ( |null_count| eq ( null_count, RowCount . new_expr ( EmptyOptions , [ ] ) ) ) )
182182 }
183183}
@@ -227,18 +227,18 @@ impl StatsRewriteRule for IsNotNullLegacyStatsRewrite {
// `falsify` for the is-not-null rewrite rule (diff view: `-` = before, `+` = after).
// Returns `Ok(Some(null_count == RowCount))` for the stat of `expr.child(0)`,
// where the right-hand side is `RowCount.new_expr(EmptyOptions, [])`, or
// `Ok(None)` when `null_count` yields no stat expression.
// The post-image renames `_ctx` to `ctx` and forwards it to `null_count`.
227227 fn falsify (
228228 & self ,
229229 expr : & Expression ,
230- _ctx : & StatsRewriteCtx < ' _ > ,
230+ ctx : & StatsRewriteCtx < ' _ > ,
231231 ) -> VortexResult < Option < Expression > > {
232- Ok ( null_count ( expr. child ( 0 ) )
232+ Ok ( null_count ( expr. child ( 0 ) , ctx )
233233 . map ( |null_count| eq ( null_count, RowCount . new_expr ( EmptyOptions , [ ] ) ) ) )
234234 }
235235
// `satisfy` for the is-not-null rewrite rule (diff view: `-` = before, `+` = after).
// Returns `Ok(Some(null_count == 0u64))` for the stat of `expr.child(0)`, or
// `Ok(None)` when `null_count` yields no stat expression.
// The post-image renames `_ctx` to `ctx` and forwards it to `null_count`.
236236 fn satisfy (
237237 & self ,
238238 expr : & Expression ,
239- _ctx : & StatsRewriteCtx < ' _ > ,
239+ ctx : & StatsRewriteCtx < ' _ > ,
240240 ) -> VortexResult < Option < Expression > > {
241- Ok ( null_count ( expr. child ( 0 ) ) . map ( |null_count| eq ( null_count, lit ( 0u64 ) ) ) )
241+ Ok ( null_count ( expr. child ( 0 ) , ctx ) . map ( |null_count| eq ( null_count, lit ( 0u64 ) ) ) )
242242 }
243243}
244244
@@ -287,7 +287,7 @@ impl StatsRewriteRule for LikeStatsRewrite {
287287 fn falsify (
288288 & self ,
289289 expr : & Expression ,
290- _ctx : & StatsRewriteCtx < ' _ > ,
290+ ctx : & StatsRewriteCtx < ' _ > ,
291291 ) -> VortexResult < Option < Expression > > {
292292 let like_options = expr. as_ :: < Like > ( ) ;
293293 if like_options. negated || like_options. case_insensitive {
@@ -304,8 +304,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
304304 let source = expr. child ( 0 ) ;
305305 Ok ( match LikeVariant :: from_str ( pattern) {
306306 Some ( LikeVariant :: Exact ( text) ) => {
307- min ( source)
308- . zip ( max ( source) )
307+ min ( source, ctx )
308+ . zip ( max ( source, ctx ) )
309309 . map ( |( source_min, source_max) | {
310310 or (
311311 gt ( source_min, lit ( text. as_ref ( ) ) ) ,
@@ -317,8 +317,8 @@ impl StatsRewriteRule for LikeStatsRewrite {
317317 let Some ( successor) = prefix. to_string ( ) . increment ( ) . ok ( ) else {
318318 return Ok ( None ) ;
319319 } ;
320- min ( source)
321- . zip ( max ( source) )
320+ min ( source, ctx )
321+ . zip ( max ( source, ctx ) )
322322 . map ( |( source_min, source_max) | {
323323 or (
324324 gt_eq ( source_min, lit ( successor) ) ,
@@ -361,10 +361,10 @@ impl StatsRewriteRule for ListContainsStatsRewrite {
361361 return Ok ( Some ( lit ( true ) ) ) ;
362362 }
363363
364- let Some ( value_max) = max ( needle) else {
364+ let Some ( value_max) = max ( needle, ctx ) else {
365365 return Ok ( None ) ;
366366 } ;
367- let Some ( value_min) = min ( needle) else {
367+ let Some ( value_min) = min ( needle, ctx ) else {
368368 return Ok ( None ) ;
369369 } ;
370370
@@ -398,10 +398,10 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {
398398
399399 let Some ( ( operator, lhs_stat) ) = ( match dynamic. operator {
400400 CompareOperator :: Eq | CompareOperator :: NotEq => None ,
401- CompareOperator :: Gt => max ( lhs) . map ( |lhs_stat| ( CompareOperator :: Lte , lhs_stat) ) ,
402- CompareOperator :: Gte => max ( lhs) . map ( |lhs_stat| ( CompareOperator :: Lt , lhs_stat) ) ,
403- CompareOperator :: Lt => min ( lhs) . map ( |lhs_stat| ( CompareOperator :: Gte , lhs_stat) ) ,
404- CompareOperator :: Lte => min ( lhs) . map ( |lhs_stat| ( CompareOperator :: Gt , lhs_stat) ) ,
401+ CompareOperator :: Gt => max ( lhs, ctx ) . map ( |lhs_stat| ( CompareOperator :: Lte , lhs_stat) ) ,
402+ CompareOperator :: Gte => max ( lhs, ctx ) . map ( |lhs_stat| ( CompareOperator :: Lt , lhs_stat) ) ,
403+ CompareOperator :: Lt => min ( lhs, ctx ) . map ( |lhs_stat| ( CompareOperator :: Gte , lhs_stat) ) ,
404+ CompareOperator :: Lte => min ( lhs, ctx ) . map ( |lhs_stat| ( CompareOperator :: Gt , lhs_stat) ) ,
405405 } ) else {
406406 return Ok ( None ) ;
407407 } ;
@@ -418,16 +418,16 @@ impl StatsRewriteRule for DynamicComparisonStatsRewrite {
418418 }
419419}
420420
// Shorthand for `stat_expr` with `Stat::Min`; returns whatever `stat_expr`
// returns, including `None`.
// Diff view: the post-image (`+`) adds a `ctx` parameter and forwards it.
421- fn min ( expr : & Expression ) -> Option < Expression > {
422- stat_expr ( expr, Stat :: Min )
421+ fn min ( expr : & Expression , ctx : & StatsRewriteCtx < ' _ > ) -> Option < Expression > {
422+ stat_expr ( expr, Stat :: Min , ctx )
423423 }
424424
// Shorthand for `stat_expr` with `Stat::Max`; returns whatever `stat_expr`
// returns, including `None`.
// Diff view: the post-image (`+`) adds a `ctx` parameter and forwards it.
425- fn max ( expr : & Expression ) -> Option < Expression > {
426- stat_expr ( expr, Stat :: Max )
425+ fn max ( expr : & Expression , ctx : & StatsRewriteCtx < ' _ > ) -> Option < Expression > {
426+ stat_expr ( expr, Stat :: Max , ctx )
427427 }
428428
// Shorthand for `stat_expr` with `Stat::NullCount`; returns whatever
// `stat_expr` returns, including `None`.
// Diff view: the post-image (`+`) adds a `ctx` parameter and forwards it.
429- fn null_count ( expr : & Expression ) -> Option < Expression > {
430- stat_expr ( expr, Stat :: NullCount )
429+ fn null_count ( expr : & Expression , ctx : & StatsRewriteCtx < ' _ > ) -> Option < Expression > {
430+ stat_expr ( expr, Stat :: NullCount , ctx )
431431 }
432432
433433fn all_null ( expr : & Expression ) -> Expression {
@@ -474,7 +474,7 @@ fn has_nans(dtype: &DType) -> bool {
474474 matches ! ( dtype, DType :: Primitive ( ptype, _) if ptype. is_float( ) )
475475}
476476
477- fn stat_expr ( expr : & Expression , stat : Stat ) -> Option < Expression > {
477+ fn stat_expr ( expr : & Expression , stat : Stat , ctx : & StatsRewriteCtx < ' _ > ) -> Option < Expression > {
478478 if let Some ( literal) = literal_stat ( expr, stat) {
479479 return Some ( literal) ;
480480 }
@@ -487,11 +487,18 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
487487 }
488488
489489 if let Some ( dtype) = expr. as_opt :: < Cast > ( ) {
490- return cast_stat ( expr. child ( 0 ) , dtype, stat) ;
491- }
492-
493- stat. aggregate_fn ( )
494- . map ( |aggregate_fn| stat_fn ( expr. clone ( ) , aggregate_fn) )
490+ return cast_stat ( expr. child ( 0 ) , dtype, stat, ctx) ;
491+ }
492+
493+ let aggregate_fn = stat. aggregate_fn ( ) ?;
494+ // The aggregate may not support the expression's dtype, e.g. min/max over structs,
495+ // even when the predicate itself is well-typed. Such stats cannot be lowered later,
496+ // so do not reference them in the rewrite.
497+ let input_dtype = ctx. return_dtype ( expr) . ok ( ) ?;
498+ aggregate_fn
499+ . return_dtype ( & input_dtype)
500+ . is_some ( )
501+ . then ( || stat_fn ( expr. clone ( ) , aggregate_fn) )
495502}
496503
497504fn with_nan_predicate (
@@ -545,10 +552,15 @@ fn literal_stat(expr: &Expression, stat: Stat) -> Option<Expression> {
545552 }
546553}
547554
// Derives the stat expression for a cast from the stat of the cast input `expr`
// (the visible caller, `stat_expr`, passes the cast child and its target `dtype`).
//   - `Min` / `Max`: the input stat wrapped in `cast(.., dtype.clone())`.
//   - `NaNCount` / `Sum` / `UncompressedSizeInBytes`: the input stat, forwarded as-is.
//   - `NullCount` / `IsConstant` / `IsSorted` / `IsStrictSorted`: always `None`.
// Returns `None` as well whenever the recursive `stat_expr` call returns `None`.
// Diff view: the post-image (`+`) adds a `ctx` parameter and threads it into
// every `stat_expr` call; the match arms are otherwise unchanged.
548- fn cast_stat ( expr : & Expression , dtype : & DType , stat : Stat ) -> Option < Expression > {
555+ fn cast_stat (
556+ expr : & Expression ,
557+ dtype : & DType ,
558+ stat : Stat ,
559+ ctx : & StatsRewriteCtx < ' _ > ,
560+ ) -> Option < Expression > {
549561 match stat {
550- Stat :: Min | Stat :: Max => stat_expr ( expr, stat) . map ( |stat| cast ( stat, dtype. clone ( ) ) ) ,
551- Stat :: NaNCount | Stat :: Sum | Stat :: UncompressedSizeInBytes => stat_expr ( expr, stat) ,
562+ Stat :: Min | Stat :: Max => stat_expr ( expr, stat, ctx ) . map ( |stat| cast ( stat, dtype. clone ( ) ) ) ,
563+ Stat :: NaNCount | Stat :: Sum | Stat :: UncompressedSizeInBytes => stat_expr ( expr, stat, ctx ) ,
552564 Stat :: NullCount | Stat :: IsConstant | Stat :: IsSorted | Stat :: IsStrictSorted => None ,
553565 }
554566 }
@@ -626,11 +638,19 @@ mod tests {
626638 ( "f" , DType :: Primitive ( PType :: F32 , Nullability :: NonNullable ) ) ,
627639 ( "s" , DType :: Utf8 ( Nullability :: NonNullable ) ) ,
628640 ( "t" , DType :: Utf8 ( Nullability :: NonNullable ) ) ,
641+ ( "n" , nested_struct_dtype ( ) ) ,
629642 ] ) ,
630643 Nullability :: NonNullable ,
631644 )
632645 }
633646
// Test fixture added by this diff: a non-nullable struct dtype with a single
// field `x` of nullable `F32`. Used as the dtype of column `n` in the test
// schema and for the struct scalar built in the struct-comparison test.
647+ fn nested_struct_dtype ( ) -> DType {
648+ DType :: Struct (
649+ StructFields :: from_iter ( [ ( "x" , DType :: Primitive ( PType :: F32 , Nullability :: Nullable ) ) ] ) ,
650+ Nullability :: NonNullable ,
651+ )
652+ }
653+
// Test helper: calls `falsify` on `expr` with `test_scope()` and the shared
// `SESSION`, returning the result unchanged.
634654 fn falsify ( expr : & Expression ) -> VortexResult < Option < Expression > > {
635655 expr. falsify ( & test_scope ( ) , & SESSION )
636656 }
@@ -848,6 +868,19 @@ mod tests {
848868 Ok ( ( ) )
849869 }
850870
// Test added by this diff. Builds a struct scalar of `nested_struct_dtype()`
// and asserts that `falsify` returns `None` for both `lt_eq` and `eq`
// comparisons between column `n` and that scalar.
871+ #[ test]
872+ fn skips_falsifier_when_min_max_unsupported_for_dtype ( ) -> VortexResult < ( ) > {
873+ // Struct comparisons are valid predicates, but min/max aggregates do not
874+ // support struct inputs, so no stats-backed falsifier should be produced.
875+ let struct_scalar = Scalar :: struct_ (
876+ nested_struct_dtype ( ) ,
877+ vec ! [ Scalar :: primitive( 1.0f32 , Nullability :: Nullable ) ] ,
878+ ) ;
879+ assert_eq ! ( falsify( & lt_eq( col( "n" ) , lit( struct_scalar. clone( ) ) ) ) ?, None ) ;
880+ assert_eq ! ( falsify( & eq( col( "n" ) , lit( struct_scalar) ) ) ?, None ) ;
881+ Ok ( ( ) )
882+ }
883+
851884 #[ test]
852885 fn forwards_min_max_through_safe_cast ( ) -> VortexResult < ( ) > {
853886 let dtype = DType :: Primitive ( PType :: I64 , Nullability :: NonNullable ) ;
0 commit comments