diff --git a/src/predicate.rs b/src/predicate.rs index 1d317ba..5e99eb7 100644 --- a/src/predicate.rs +++ b/src/predicate.rs @@ -65,6 +65,20 @@ pub enum ComparisonOp { GreaterThanOrEqual, } +impl ComparisonOp { + /// Returns the negated comparison operator. + pub fn negate(&self) -> Self { + match self { + ComparisonOp::Equal => ComparisonOp::NotEqual, + ComparisonOp::NotEqual => ComparisonOp::Equal, + ComparisonOp::LessThan => ComparisonOp::GreaterThanOrEqual, + ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThan, + ComparisonOp::GreaterThan => ComparisonOp::LessThanOrEqual, + ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThan, + } + } +} + /// A predicate that can be evaluated against row group statistics /// /// Predicates are simplified expressions used for filtering row groups before diff --git a/src/row_group_filter.rs b/src/row_group_filter.rs index 5672965..28b1531 100644 --- a/src/row_group_filter.rs +++ b/src/row_group_filter.rs @@ -106,15 +106,39 @@ fn evaluate_predicate_recursive( result[i] = temp_results.iter().any(|tr| tr[i]); } } - Predicate::Not(predicate) => { - // For NOT: evaluate predicate, then negate - let mut temp_result = vec![true; result.len()]; - evaluate_predicate_recursive(predicate, row_index, schema, &mut temp_result)?; - // NOT logic: result[i] = !temp_result[i] - for (r, t) in result.iter_mut().zip(temp_result.iter()) { - *r = !*t; + Predicate::Not(predicate) => match &**predicate { + Predicate::Not(inner) => { + evaluate_predicate_recursive(inner, row_index, schema, result)?; } - } + Predicate::IsNull { column } => { + evaluate_is_not_null(column, row_index, schema, result)?; + } + Predicate::IsNotNull { column } => { + evaluate_is_null(column, row_index, schema, result)?; + } + Predicate::Comparison { column, op, value } => { + evaluate_comparison(column, op.negate(), value, row_index, schema, result)?; + } + Predicate::And(predicates) => { + let not_preds: Vec<Predicate> = predicates + .iter() + .map(|p| 
Predicate::Not(Box::new(p.clone()))) + .collect(); + evaluate_predicate_recursive(&Predicate::Or(not_preds), row_index, schema, result)?; + } + Predicate::Or(predicates) => { + let not_preds: Vec<Predicate> = predicates + .iter() + .map(|p| Predicate::Not(Box::new(p.clone()))) + .collect(); + evaluate_predicate_recursive( + &Predicate::And(not_preds), + row_index, + schema, + result, + )?; + } + }, } Ok(()) @@ -1015,26 +1039,289 @@ mod tests { } #[test] - fn test_evaluate_predicate_missing_statistics() { + fn test_evaluate_predicate_not_is_null() { + use crate::predicate::Predicate; + use crate::row_index::{RowGroupEntry, RowGroupIndex}; + use std::collections::HashMap; + + // Create row index with mixed nulls and values + let mut columns = HashMap::new(); + let entries = vec![RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(5000), + has_null: Some(true), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(18), + maximum: Some(25), + sum: Some(107500), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + )]; + columns.insert(1, RowGroupIndex::new(entries, 10000, 1)); + let row_index = StripeRowIndex::new(columns, 10000, 10000); + let schema = create_test_schema(); + + // Test: Not(age IS NULL) -> age IS NOT NULL + let predicate = Predicate::not(Predicate::is_null("age")); + let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap(); + + assert_eq!(result.len(), 1); + assert!(result[0]); // Should keep because there are non-null values + } + + #[test] + fn test_evaluate_predicate_not_is_not_null() { + use crate::predicate::Predicate; + use crate::row_index::{RowGroupEntry, RowGroupIndex}; + use std::collections::HashMap; + + // Create row index with mixed nulls and values + let mut columns = HashMap::new(); + let entries = vec![ + // Row group 0: Has nulls (and values) + RowGroupEntry::new( + Some({ + let proto_stats = 
proto::ColumnStatistics { + number_of_values: Some(5000), + has_null: Some(true), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(18), + maximum: Some(25), + sum: Some(107500), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + ), + // Row group 1: No nulls + RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(26), + maximum: Some(65), + sum: Some(455000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + ), + ]; + columns.insert(1, RowGroupIndex::new(entries, 10000, 1)); + let row_index = StripeRowIndex::new(columns, 20000, 10000); + let schema = create_test_schema(); + + // Test: Not(age IS NOT NULL) -> age IS NULL + let predicate = Predicate::not(Predicate::is_not_null("age")); + let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap(); + + assert_eq!(result.len(), 2); + assert!(result[0]); // Row group 0: has_null = true -> Keep + assert!(!result[1]); // Row group 1: has_null = false -> Skip + } + + #[test] + fn test_evaluate_predicate_not_comparison() { + use crate::predicate::{Predicate, PredicateValue}; + use crate::row_index::{RowGroupEntry, RowGroupIndex}; + use std::collections::HashMap; + + let mut columns = HashMap::new(); + let entries = vec![RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(0), + maximum: Some(10), + sum: Some(50000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + )]; + columns.insert(1, RowGroupIndex::new(entries, 10000, 1)); + let row_index = StripeRowIndex::new(columns, 10000, 10000); + let schema = create_test_schema(); + + // Test: 
Not(age > 5) -> age <= 5 + let predicate = Predicate::not(Predicate::gt("age", PredicateValue::Int32(Some(5)))); + let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap(); + + assert_eq!(result.len(), 1); + assert!(result[0]); + } + + #[test] + fn test_evaluate_predicate_not_and() { use crate::predicate::{Predicate, PredicateValue}; use crate::row_index::{RowGroupEntry, RowGroupIndex}; use std::collections::HashMap; - // Create row index with missing statistics let mut columns = HashMap::new(); let entries = vec![ - RowGroupEntry::new(None, vec![]), // No statistics + RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(0), + maximum: Some(10), + sum: Some(50000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + ), + RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(20), + maximum: Some(30), + sum: Some(250000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + ), + ]; + columns.insert(1, RowGroupIndex::new(entries, 10000, 1)); + let row_index = StripeRowIndex::new(columns, 20000, 10000); + let schema = create_test_schema(); + + // Test: Not(age >= 15 AND age <= 25) + // Equivalent to: age < 15 OR age > 25 + // Row Group 1: [0, 10] -> Fits age < 15 -> Keep + // Row Group 2: [20, 30] -> Fits age > 25 -> Keep + let predicate = Predicate::not(Predicate::and(vec![ + Predicate::gte("age", PredicateValue::Int32(Some(15))), + Predicate::lte("age", PredicateValue::Int32(Some(25))), + ])); + + let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap(); + + assert_eq!(result.len(), 2); + assert!(result[0]); // [0, 10] is < 15 + 
assert!(result[1]); // [20, 30] contains values > 25 (26..30) + } + + #[test] + fn test_evaluate_predicate_not_or() { + use crate::predicate::{Predicate, PredicateValue}; + use crate::row_index::{RowGroupEntry, RowGroupIndex}; + use std::collections::HashMap; + + let mut columns = HashMap::new(); + let entries = vec![ + RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(0), + maximum: Some(5), + sum: Some(25000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + ), + RowGroupEntry::new( + Some({ + let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(5), + maximum: Some(15), + sum: Some(100000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + ), ]; columns.insert(1, RowGroupIndex::new(entries, 10000, 1)); + let row_index = StripeRowIndex::new(columns, 20000, 10000); + let schema = create_test_schema(); + + // Test: Not(age < 10 OR age > 30) + // Equivalent to: age >= 10 AND age <= 30 + let predicate = Predicate::not(Predicate::or(vec![ + Predicate::lt("age", PredicateValue::Int32(Some(10))), + Predicate::gt("age", PredicateValue::Int32(Some(30))), + ])); + let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap(); + + assert_eq!(result.len(), 2); + assert!(!result[0]); // [0, 5] is outside [10, 30] -> Skip + assert!(result[1]); // [5, 15] overlaps [10, 30] -> Keep + } + + #[test] + fn test_evaluate_predicate_double_negation() { + use crate::predicate::{Predicate, PredicateValue}; + use crate::row_index::{RowGroupEntry, RowGroupIndex}; + use std::collections::HashMap; + + let mut columns = HashMap::new(); + // Row group: [0, 10] + let entries = vec![RowGroupEntry::new( + Some({ 
+ let proto_stats = proto::ColumnStatistics { + number_of_values: Some(10000), + has_null: Some(false), + int_statistics: Some(proto::IntegerStatistics { + minimum: Some(0), + maximum: Some(10), + sum: Some(50000), + }), + ..Default::default() + }; + ColumnStatistics::try_from(&proto_stats).unwrap() + }), + vec![], + )]; + columns.insert(1, RowGroupIndex::new(entries, 10000, 1)); let row_index = StripeRowIndex::new(columns, 10000, 10000); let schema = create_test_schema(); - // Test: age > 10 - // Should keep row group when statistics are missing (conservative) - let predicate = Predicate::gt("age", PredicateValue::Int32(Some(10))); + // Test: Not(Not(age > 5)) -> age > 5 + // Row group [0, 10] contains values > 5 -> Keep + let predicate = Predicate::not(Predicate::not(Predicate::gt( + "age", + PredicateValue::Int32(Some(5)), + ))); let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap(); assert_eq!(result.len(), 1); - assert!(result[0]); // Keep when statistics missing + assert!(result[0]); } }