Skip to content

Commit 5a5a69b

Browse files
committed
fix NOT logic in predicate
1 parent 1b7a553 commit 5a5a69b

2 files changed

Lines changed: 335 additions & 13 deletions

File tree

src/predicate.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ pub enum ComparisonOp {
6565
GreaterThanOrEqual,
6666
}
6767

68+
impl ComparisonOp {
69+
/// Returns the negated comparison operator.
70+
pub fn negate(&self) -> Self {
71+
match self {
72+
ComparisonOp::Equal => ComparisonOp::NotEqual,
73+
ComparisonOp::NotEqual => ComparisonOp::Equal,
74+
ComparisonOp::LessThan => ComparisonOp::GreaterThanOrEqual,
75+
ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThan,
76+
ComparisonOp::GreaterThan => ComparisonOp::LessThanOrEqual,
77+
ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThan,
78+
}
79+
}
80+
}
81+
6882
/// A predicate that can be evaluated against row group statistics
6983
///
7084
/// Predicates are simplified expressions used for filtering row groups before

src/row_group_filter.rs

Lines changed: 321 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,45 @@ fn evaluate_predicate_recursive(
107107
}
108108
}
109109
Predicate::Not(predicate) => {
110-
// For NOT: evaluate predicate, then negate
111-
let mut temp_result = vec![true; result.len()];
112-
evaluate_predicate_recursive(predicate, row_index, schema, &mut temp_result)?;
113-
// NOT logic: result[i] = !temp_result[i]
114-
for (r, t) in result.iter_mut().zip(temp_result.iter()) {
115-
*r = !*t;
110+
match &**predicate {
111+
Predicate::Not(inner) => {
112+
evaluate_predicate_recursive(inner, row_index, schema, result)?;
113+
}
114+
Predicate::IsNull { column } => {
115+
evaluate_is_not_null(column, row_index, schema, result)?;
116+
}
117+
Predicate::IsNotNull { column } => {
118+
evaluate_is_null(column, row_index, schema, result)?;
119+
}
120+
Predicate::Comparison { column, op, value } => {
121+
evaluate_comparison(column, op.negate(), value, row_index, schema, result)?;
122+
}
123+
Predicate::And(predicates) => {
124+
// De Morgan's: Not(And(A, B)) -> Or(Not(A), Not(B))
125+
let not_preds: Vec<Predicate> = predicates
126+
.iter()
127+
.map(|p| Predicate::Not(Box::new(p.clone())))
128+
.collect();
129+
evaluate_predicate_recursive(
130+
&Predicate::Or(not_preds),
131+
row_index,
132+
schema,
133+
result,
134+
)?;
135+
}
136+
Predicate::Or(predicates) => {
137+
// De Morgan's: Not(Or(A, B)) -> And(Not(A), Not(B))
138+
let not_preds: Vec<Predicate> = predicates
139+
.iter()
140+
.map(|p| Predicate::Not(Box::new(p.clone())))
141+
.collect();
142+
evaluate_predicate_recursive(
143+
&Predicate::And(not_preds),
144+
row_index,
145+
schema,
146+
result,
147+
)?;
148+
}
116149
}
117150
}
118151
}
@@ -1015,26 +1048,301 @@ mod tests {
10151048
}
10161049

10171050
#[test]
1018-
fn test_evaluate_predicate_missing_statistics() {
1051+
fn test_evaluate_predicate_not_is_null() {
1052+
use crate::predicate::Predicate;
1053+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1054+
use std::collections::HashMap;
1055+
1056+
// Create row index with mixed nulls and values
1057+
let mut columns = HashMap::new();
1058+
let entries = vec![
1059+
RowGroupEntry::new(
1060+
Some({
1061+
let proto_stats = proto::ColumnStatistics {
1062+
number_of_values: Some(5000), // Has non-null values
1063+
has_null: Some(true), // AND has nulls
1064+
int_statistics: Some(proto::IntegerStatistics {
1065+
minimum: Some(18),
1066+
maximum: Some(25),
1067+
sum: Some(107500),
1068+
}),
1069+
..Default::default()
1070+
};
1071+
ColumnStatistics::try_from(&proto_stats).unwrap()
1072+
}),
1073+
vec![],
1074+
),
1075+
];
1076+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1077+
let row_index = StripeRowIndex::new(columns, 10000, 10000);
1078+
let schema = create_test_schema();
1079+
1080+
// Test: Not(age IS NULL) -> age IS NOT NULL
1081+
let predicate = Predicate::not(Predicate::is_null("age"));
1082+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1083+
1084+
assert_eq!(result.len(), 1);
1085+
assert!(result[0]); // Should keep because there are non-null values
1086+
}
1087+
1088+
#[test]
1089+
fn test_evaluate_predicate_not_is_not_null() {
1090+
use crate::predicate::Predicate;
1091+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1092+
use std::collections::HashMap;
1093+
1094+
// Create row index with mixed nulls and values
1095+
let mut columns = HashMap::new();
1096+
let entries = vec![
1097+
// Row group 0: Has nulls (and values)
1098+
RowGroupEntry::new(
1099+
Some({
1100+
let proto_stats = proto::ColumnStatistics {
1101+
number_of_values: Some(5000),
1102+
has_null: Some(true),
1103+
int_statistics: Some(proto::IntegerStatistics {
1104+
minimum: Some(18),
1105+
maximum: Some(25),
1106+
sum: Some(107500),
1107+
}),
1108+
..Default::default()
1109+
};
1110+
ColumnStatistics::try_from(&proto_stats).unwrap()
1111+
}),
1112+
vec![],
1113+
),
1114+
// Row group 1: No nulls
1115+
RowGroupEntry::new(
1116+
Some({
1117+
let proto_stats = proto::ColumnStatistics {
1118+
number_of_values: Some(5000),
1119+
has_null: Some(false),
1120+
int_statistics: Some(proto::IntegerStatistics {
1121+
minimum: Some(26),
1122+
maximum: Some(65),
1123+
sum: Some(227500),
1124+
}),
1125+
..Default::default()
1126+
};
1127+
ColumnStatistics::try_from(&proto_stats).unwrap()
1128+
}),
1129+
vec![],
1130+
),
1131+
];
1132+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1133+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1134+
let schema = create_test_schema();
1135+
1136+
// Test: Not(age IS NOT NULL) -> age IS NULL
1137+
let predicate = Predicate::not(Predicate::is_not_null("age"));
1138+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1139+
1140+
assert_eq!(result.len(), 2);
1141+
assert!(result[0]); // Row group 0: has_null = true -> Keep
1142+
assert!(!result[1]); // Row group 1: has_null = false -> Skip
1143+
}
1144+
1145+
#[test]
1146+
fn test_evaluate_predicate_not_comparison() {
10191147
use crate::predicate::{Predicate, PredicateValue};
10201148
use crate::row_index::{RowGroupEntry, RowGroupIndex};
10211149
use std::collections::HashMap;
10221150

1023-
// Create row index with missing statistics
1151+
let mut columns = HashMap::new();
1152+
// Row group: [0, 10]
1153+
let entries = vec![RowGroupEntry::new(
1154+
Some({
1155+
let proto_stats = proto::ColumnStatistics {
1156+
number_of_values: Some(1000),
1157+
has_null: Some(false),
1158+
int_statistics: Some(proto::IntegerStatistics {
1159+
minimum: Some(0),
1160+
maximum: Some(10),
1161+
sum: Some(5000),
1162+
}),
1163+
..Default::default()
1164+
};
1165+
ColumnStatistics::try_from(&proto_stats).unwrap()
1166+
}),
1167+
vec![],
1168+
)];
1169+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1170+
let row_index = StripeRowIndex::new(columns, 10000, 10000);
1171+
let schema = create_test_schema();
1172+
1173+
// Test: Not(age > 5) -> age <= 5
1174+
let predicate = Predicate::not(Predicate::gt("age", PredicateValue::Int32(Some(5))));
1175+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1176+
1177+
assert_eq!(result.len(), 1);
1178+
assert!(result[0]);
1179+
}
1180+
1181+
#[test]
1182+
fn test_evaluate_predicate_not_and() {
1183+
use crate::predicate::{Predicate, PredicateValue};
1184+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1185+
use std::collections::HashMap;
1186+
1187+
// Row group 1: [0, 10]
1188+
// Row group 2: [20, 30]
10241189
let mut columns = HashMap::new();
10251190
let entries = vec![
1026-
RowGroupEntry::new(None, vec![]), // No statistics
1191+
RowGroupEntry::new(
1192+
Some({
1193+
let proto_stats = proto::ColumnStatistics {
1194+
number_of_values: Some(1000),
1195+
has_null: Some(false),
1196+
int_statistics: Some(proto::IntegerStatistics {
1197+
minimum: Some(0),
1198+
maximum: Some(10),
1199+
sum: Some(5000),
1200+
}),
1201+
..Default::default()
1202+
};
1203+
ColumnStatistics::try_from(&proto_stats).unwrap()
1204+
}),
1205+
vec![],
1206+
),
1207+
RowGroupEntry::new(
1208+
Some({
1209+
let proto_stats = proto::ColumnStatistics {
1210+
number_of_values: Some(1000),
1211+
has_null: Some(false),
1212+
int_statistics: Some(proto::IntegerStatistics {
1213+
minimum: Some(20),
1214+
maximum: Some(30),
1215+
sum: Some(25000),
1216+
}),
1217+
..Default::default()
1218+
};
1219+
ColumnStatistics::try_from(&proto_stats).unwrap()
1220+
}),
1221+
vec![],
1222+
),
1223+
];
1224+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1225+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1226+
let schema = create_test_schema();
1227+
1228+
// Test: Not(age >= 15 AND age <= 25)
1229+
// Equivalent to: age < 15 OR age > 25
1230+
// Row Group 1: [0, 10] -> Fits age < 15 -> Keep
1231+
// Row Group 2: [20, 30] -> Fits age > 25 -> Keep
1232+
let predicate = Predicate::not(Predicate::and(vec![
1233+
Predicate::gte("age", PredicateValue::Int32(Some(15))),
1234+
Predicate::lte("age", PredicateValue::Int32(Some(25))),
1235+
]));
1236+
1237+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1238+
1239+
assert_eq!(result.len(), 2);
1240+
assert!(result[0]); // [0, 10] is < 15
1241+
assert!(result[1]); // [20, 30] contains values > 25 (26..30)
1242+
}
1243+
1244+
#[test]
1245+
fn test_evaluate_predicate_not_or() {
1246+
use crate::predicate::{Predicate, PredicateValue};
1247+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1248+
use std::collections::HashMap;
1249+
1250+
let mut columns = HashMap::new();
1251+
let entries = vec![
1252+
// Row group 0: [0, 5]
1253+
// Fits age < 10 fully. Does not overlap [10, 30].
1254+
RowGroupEntry::new(
1255+
Some({
1256+
let proto_stats = proto::ColumnStatistics {
1257+
number_of_values: Some(1000),
1258+
has_null: Some(false),
1259+
int_statistics: Some(proto::IntegerStatistics {
1260+
minimum: Some(0),
1261+
maximum: Some(5),
1262+
sum: Some(2500),
1263+
}),
1264+
..Default::default()
1265+
};
1266+
ColumnStatistics::try_from(&proto_stats).unwrap()
1267+
}),
1268+
vec![],
1269+
),
1270+
// Row group 1: [5, 15]
1271+
// Fits age < 10 (partially).
1272+
// ALSO overlaps [10, 30] (values 10-15).
1273+
// Old logic: evaluate(age < 10) is true -> evaluate(OR) is true -> Not(OR) is false (Skip). WRONG.
1274+
// New logic: Not(OR) -> And(age >= 10, age <= 30). [5, 15] overlaps [10, 30]. Keep.
1275+
RowGroupEntry::new(
1276+
Some({
1277+
let proto_stats = proto::ColumnStatistics {
1278+
number_of_values: Some(1000),
1279+
has_null: Some(false),
1280+
int_statistics: Some(proto::IntegerStatistics {
1281+
minimum: Some(5),
1282+
maximum: Some(15),
1283+
sum: Some(10000),
1284+
}),
1285+
..Default::default()
1286+
};
1287+
ColumnStatistics::try_from(&proto_stats).unwrap()
1288+
}),
1289+
vec![],
1290+
),
10271291
];
10281292
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1293+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1294+
let schema = create_test_schema();
1295+
1296+
// Test: Not(age < 10 OR age > 30)
1297+
// Equivalent to: age >= 10 AND age <= 30
1298+
let predicate = Predicate::not(Predicate::or(vec![
1299+
Predicate::lt("age", PredicateValue::Int32(Some(10))),
1300+
Predicate::gt("age", PredicateValue::Int32(Some(30))),
1301+
]));
1302+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1303+
1304+
assert_eq!(result.len(), 2);
1305+
assert!(!result[0]); // [0, 5] is outside [10, 30] -> Skip
1306+
assert!(result[1]); // [5, 15] overlaps [10, 30] -> Keep
1307+
}
1308+
1309+
#[test]
1310+
fn test_evaluate_predicate_double_negation() {
1311+
use crate::predicate::{Predicate, PredicateValue};
1312+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1313+
use std::collections::HashMap;
1314+
1315+
let mut columns = HashMap::new();
1316+
// Row group: [0, 10]
1317+
let entries = vec![RowGroupEntry::new(
1318+
Some({
1319+
let proto_stats = proto::ColumnStatistics {
1320+
number_of_values: Some(1000),
1321+
has_null: Some(false),
1322+
int_statistics: Some(proto::IntegerStatistics {
1323+
minimum: Some(0),
1324+
maximum: Some(10),
1325+
sum: Some(5000),
1326+
}),
1327+
..Default::default()
1328+
};
1329+
ColumnStatistics::try_from(&proto_stats).unwrap()
1330+
}),
1331+
vec![],
1332+
)];
1333+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
10291334
let row_index = StripeRowIndex::new(columns, 10000, 10000);
10301335
let schema = create_test_schema();
10311336

1032-
// Test: age > 10
1033-
// Should keep row group when statistics are missing (conservative)
1034-
let predicate = Predicate::gt("age", PredicateValue::Int32(Some(10)));
1337+
// Test: Not(Not(age > 5)) -> age > 5
1338+
// Row group [0, 10] contains values > 5 -> Keep
1339+
let predicate = Predicate::not(Predicate::not(Predicate::gt(
1340+
"age",
1341+
PredicateValue::Int32(Some(5)),
1342+
)));
10351343
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
10361344

10371345
assert_eq!(result.len(), 1);
1038-
assert!(result[0]); // Keep when statistics missing
1346+
assert!(result[0]);
10391347
}
10401348
}

0 commit comments

Comments
 (0)