Skip to content

Commit fd9f67a

Browse files
committed
fix NOT logic in predicate
1 parent 1b7a553 commit fd9f67a

2 files changed

Lines changed: 323 additions & 13 deletions

File tree

src/predicate.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ pub enum ComparisonOp {
6565
GreaterThanOrEqual,
6666
}
6767

68+
impl ComparisonOp {
69+
/// Returns the negated comparison operator.
70+
pub fn negate(&self) -> Self {
71+
match self {
72+
ComparisonOp::Equal => ComparisonOp::NotEqual,
73+
ComparisonOp::NotEqual => ComparisonOp::Equal,
74+
ComparisonOp::LessThan => ComparisonOp::GreaterThanOrEqual,
75+
ComparisonOp::LessThanOrEqual => ComparisonOp::GreaterThan,
76+
ComparisonOp::GreaterThan => ComparisonOp::LessThanOrEqual,
77+
ComparisonOp::GreaterThanOrEqual => ComparisonOp::LessThan,
78+
}
79+
}
80+
}
81+
6882
/// A predicate that can be evaluated against row group statistics
6983
///
7084
/// Predicates are simplified expressions used for filtering row groups before

src/row_group_filter.rs

Lines changed: 309 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,43 @@ fn evaluate_predicate_recursive(
107107
}
108108
}
109109
Predicate::Not(predicate) => {
110-
// For NOT: evaluate predicate, then negate
111-
let mut temp_result = vec![true; result.len()];
112-
evaluate_predicate_recursive(predicate, row_index, schema, &mut temp_result)?;
113-
// NOT logic: result[i] = !temp_result[i]
114-
for (r, t) in result.iter_mut().zip(temp_result.iter()) {
115-
*r = !*t;
110+
match &**predicate {
111+
Predicate::Not(inner) => {
112+
evaluate_predicate_recursive(inner, row_index, schema, result)?;
113+
}
114+
Predicate::IsNull { column } => {
115+
evaluate_is_not_null(column, row_index, schema, result)?;
116+
}
117+
Predicate::IsNotNull { column } => {
118+
evaluate_is_null(column, row_index, schema, result)?;
119+
}
120+
Predicate::Comparison { column, op, value } => {
121+
evaluate_comparison(column, op.negate(), value, row_index, schema, result)?;
122+
}
123+
Predicate::And(predicates) => {
124+
let not_preds: Vec<Predicate> = predicates
125+
.iter()
126+
.map(|p| Predicate::Not(Box::new(p.clone())))
127+
.collect();
128+
evaluate_predicate_recursive(
129+
&Predicate::Or(not_preds),
130+
row_index,
131+
schema,
132+
result,
133+
)?;
134+
}
135+
Predicate::Or(predicates) => {
136+
let not_preds: Vec<Predicate> = predicates
137+
.iter()
138+
.map(|p| Predicate::Not(Box::new(p.clone())))
139+
.collect();
140+
evaluate_predicate_recursive(
141+
&Predicate::And(not_preds),
142+
row_index,
143+
schema,
144+
result,
145+
)?;
146+
}
116147
}
117148
}
118149
}
@@ -1015,26 +1046,291 @@ mod tests {
10151046
}
10161047

10171048
#[test]
1018-
fn test_evaluate_predicate_missing_statistics() {
1049+
fn test_evaluate_predicate_not_is_null() {
1050+
use crate::predicate::Predicate;
1051+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1052+
use std::collections::HashMap;
1053+
1054+
// Create row index with mixed nulls and values
1055+
let mut columns = HashMap::new();
1056+
let entries = vec![
1057+
RowGroupEntry::new(
1058+
Some({
1059+
let proto_stats = proto::ColumnStatistics {
1060+
number_of_values: Some(5000),
1061+
has_null: Some(true),
1062+
int_statistics: Some(proto::IntegerStatistics {
1063+
minimum: Some(18),
1064+
maximum: Some(25),
1065+
sum: Some(107500),
1066+
}),
1067+
..Default::default()
1068+
};
1069+
ColumnStatistics::try_from(&proto_stats).unwrap()
1070+
}),
1071+
vec![],
1072+
),
1073+
];
1074+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1075+
let row_index = StripeRowIndex::new(columns, 10000, 10000);
1076+
let schema = create_test_schema();
1077+
1078+
// Test: Not(age IS NULL) -> age IS NOT NULL
1079+
let predicate = Predicate::not(Predicate::is_null("age"));
1080+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1081+
1082+
assert_eq!(result.len(), 1);
1083+
assert!(result[0]); // Should keep because there are non-null values
1084+
}
1085+
1086+
#[test]
1087+
fn test_evaluate_predicate_not_is_not_null() {
1088+
use crate::predicate::Predicate;
1089+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1090+
use std::collections::HashMap;
1091+
1092+
// Create row index with mixed nulls and values
1093+
let mut columns = HashMap::new();
1094+
let entries = vec![
1095+
// Row group 0: Has nulls (and values)
1096+
RowGroupEntry::new(
1097+
Some({
1098+
let proto_stats = proto::ColumnStatistics {
1099+
number_of_values: Some(5000),
1100+
has_null: Some(true),
1101+
int_statistics: Some(proto::IntegerStatistics {
1102+
minimum: Some(18),
1103+
maximum: Some(25),
1104+
sum: Some(107500),
1105+
}),
1106+
..Default::default()
1107+
};
1108+
ColumnStatistics::try_from(&proto_stats).unwrap()
1109+
}),
1110+
vec![],
1111+
),
1112+
// Row group 1: No nulls
1113+
RowGroupEntry::new(
1114+
Some({
1115+
let proto_stats = proto::ColumnStatistics {
1116+
number_of_values: Some(10000),
1117+
has_null: Some(false),
1118+
int_statistics: Some(proto::IntegerStatistics {
1119+
minimum: Some(26),
1120+
maximum: Some(65),
1121+
sum: Some(455000),
1122+
}),
1123+
..Default::default()
1124+
};
1125+
ColumnStatistics::try_from(&proto_stats).unwrap()
1126+
}),
1127+
vec![],
1128+
),
1129+
];
1130+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1131+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1132+
let schema = create_test_schema();
1133+
1134+
// Test: Not(age IS NOT NULL) -> age IS NULL
1135+
let predicate = Predicate::not(Predicate::is_not_null("age"));
1136+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1137+
1138+
assert_eq!(result.len(), 2);
1139+
assert!(result[0]); // Row group 0: has_null = true -> Keep
1140+
assert!(!result[1]); // Row group 1: has_null = false -> Skip
1141+
}
1142+
1143+
#[test]
1144+
fn test_evaluate_predicate_not_comparison() {
1145+
use crate::predicate::{Predicate, PredicateValue};
1146+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1147+
use std::collections::HashMap;
1148+
1149+
let mut columns = HashMap::new();
1150+
let entries = vec![RowGroupEntry::new(
1151+
Some({
1152+
let proto_stats = proto::ColumnStatistics {
1153+
number_of_values: Some(10000),
1154+
has_null: Some(false),
1155+
int_statistics: Some(proto::IntegerStatistics {
1156+
minimum: Some(0),
1157+
maximum: Some(10),
1158+
sum: Some(50000),
1159+
}),
1160+
..Default::default()
1161+
};
1162+
ColumnStatistics::try_from(&proto_stats).unwrap()
1163+
}),
1164+
vec![],
1165+
)];
1166+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1167+
let row_index = StripeRowIndex::new(columns, 10000, 10000);
1168+
let schema = create_test_schema();
1169+
1170+
// Test: Not(age > 5) -> age <= 5
1171+
let predicate = Predicate::not(Predicate::gt("age", PredicateValue::Int32(Some(5))));
1172+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1173+
1174+
assert_eq!(result.len(), 1);
1175+
assert!(result[0]);
1176+
}
1177+
1178+
#[test]
1179+
fn test_evaluate_predicate_not_and() {
10191180
use crate::predicate::{Predicate, PredicateValue};
10201181
use crate::row_index::{RowGroupEntry, RowGroupIndex};
10211182
use std::collections::HashMap;
10221183

1023-
// Create row index with missing statistics
10241184
let mut columns = HashMap::new();
10251185
let entries = vec![
1026-
RowGroupEntry::new(None, vec![]), // No statistics
1186+
RowGroupEntry::new(
1187+
Some({
1188+
let proto_stats = proto::ColumnStatistics {
1189+
number_of_values: Some(10000),
1190+
has_null: Some(false),
1191+
int_statistics: Some(proto::IntegerStatistics {
1192+
minimum: Some(0),
1193+
maximum: Some(10),
1194+
sum: Some(50000),
1195+
}),
1196+
..Default::default()
1197+
};
1198+
ColumnStatistics::try_from(&proto_stats).unwrap()
1199+
}),
1200+
vec![],
1201+
),
1202+
RowGroupEntry::new(
1203+
Some({
1204+
let proto_stats = proto::ColumnStatistics {
1205+
number_of_values: Some(10000),
1206+
has_null: Some(false),
1207+
int_statistics: Some(proto::IntegerStatistics {
1208+
minimum: Some(20),
1209+
maximum: Some(30),
1210+
sum: Some(250000),
1211+
}),
1212+
..Default::default()
1213+
};
1214+
ColumnStatistics::try_from(&proto_stats).unwrap()
1215+
}),
1216+
vec![],
1217+
),
10271218
];
10281219
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1220+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1221+
let schema = create_test_schema();
1222+
1223+
// Test: Not(age >= 15 AND age <= 25)
1224+
// Equivalent to: age < 15 OR age > 25
1225+
// Row Group 1: [0, 10] -> Fits age < 15 -> Keep
1226+
// Row Group 2: [20, 30] -> Fits age > 25 -> Keep
1227+
let predicate = Predicate::not(Predicate::and(vec![
1228+
Predicate::gte("age", PredicateValue::Int32(Some(15))),
1229+
Predicate::lte("age", PredicateValue::Int32(Some(25))),
1230+
]));
1231+
1232+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1233+
1234+
assert_eq!(result.len(), 2);
1235+
assert!(result[0]); // [0, 10] is < 15
1236+
assert!(result[1]); // [20, 30] contains values > 25 (26..30)
1237+
}
1238+
1239+
#[test]
1240+
fn test_evaluate_predicate_not_or() {
1241+
use crate::predicate::{Predicate, PredicateValue};
1242+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1243+
use std::collections::HashMap;
1244+
1245+
let mut columns = HashMap::new();
1246+
let entries = vec![
1247+
RowGroupEntry::new(
1248+
Some({
1249+
let proto_stats = proto::ColumnStatistics {
1250+
number_of_values: Some(10000),
1251+
has_null: Some(false),
1252+
int_statistics: Some(proto::IntegerStatistics {
1253+
minimum: Some(0),
1254+
maximum: Some(5),
1255+
sum: Some(25000),
1256+
}),
1257+
..Default::default()
1258+
};
1259+
ColumnStatistics::try_from(&proto_stats).unwrap()
1260+
}),
1261+
vec![],
1262+
),
1263+
RowGroupEntry::new(
1264+
Some({
1265+
let proto_stats = proto::ColumnStatistics {
1266+
number_of_values: Some(10000),
1267+
has_null: Some(false),
1268+
int_statistics: Some(proto::IntegerStatistics {
1269+
minimum: Some(5),
1270+
maximum: Some(15),
1271+
sum: Some(100000),
1272+
}),
1273+
..Default::default()
1274+
};
1275+
ColumnStatistics::try_from(&proto_stats).unwrap()
1276+
}),
1277+
vec![],
1278+
),
1279+
];
1280+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
1281+
let row_index = StripeRowIndex::new(columns, 20000, 10000);
1282+
let schema = create_test_schema();
1283+
1284+
// Test: Not(age < 10 OR age > 30)
1285+
// Equivalent to: age >= 10 AND age <= 30
1286+
let predicate = Predicate::not(Predicate::or(vec![
1287+
Predicate::lt("age", PredicateValue::Int32(Some(10))),
1288+
Predicate::gt("age", PredicateValue::Int32(Some(30))),
1289+
]));
1290+
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
1291+
1292+
assert_eq!(result.len(), 2);
1293+
assert!(!result[0]); // [0, 5] is outside [10, 30] -> Skip
1294+
assert!(result[1]); // [5, 15] overlaps [10, 30] -> Keep
1295+
}
1296+
1297+
#[test]
1298+
fn test_evaluate_predicate_double_negation() {
1299+
use crate::predicate::{Predicate, PredicateValue};
1300+
use crate::row_index::{RowGroupEntry, RowGroupIndex};
1301+
use std::collections::HashMap;
1302+
1303+
let mut columns = HashMap::new();
1304+
// Row group: [0, 10]
1305+
let entries = vec![RowGroupEntry::new(
1306+
Some({
1307+
let proto_stats = proto::ColumnStatistics {
1308+
number_of_values: Some(10000),
1309+
has_null: Some(false),
1310+
int_statistics: Some(proto::IntegerStatistics {
1311+
minimum: Some(0),
1312+
maximum: Some(10),
1313+
sum: Some(50000),
1314+
}),
1315+
..Default::default()
1316+
};
1317+
ColumnStatistics::try_from(&proto_stats).unwrap()
1318+
}),
1319+
vec![],
1320+
)];
1321+
columns.insert(1, RowGroupIndex::new(entries, 10000, 1));
10291322
let row_index = StripeRowIndex::new(columns, 10000, 10000);
10301323
let schema = create_test_schema();
10311324

1032-
// Test: age > 10
1033-
// Should keep row group when statistics are missing (conservative)
1034-
let predicate = Predicate::gt("age", PredicateValue::Int32(Some(10)));
1325+
// Test: Not(Not(age > 5)) -> age > 5
1326+
// Row group [0, 10] contains values > 5 -> Keep
1327+
let predicate = Predicate::not(Predicate::not(Predicate::gt(
1328+
"age",
1329+
PredicateValue::Int32(Some(5)),
1330+
)));
10351331
let result = super::evaluate_predicate(&predicate, &row_index, &schema).unwrap();
10361332

10371333
assert_eq!(result.len(), 1);
1038-
assert!(result[0]); // Keep when statistics missing
1334+
assert!(result[0]);
10391335
}
10401336
}

0 commit comments

Comments
 (0)