Skip to content

Commit 152c885

Browse files
authored
Support view matcher with boolean binary operation (#117)
1 parent 429e52a commit 152c885

1 file changed

Lines changed: 208 additions & 0 deletions

File tree

src/rewrite/normal_form.rs

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,27 @@ impl Predicate {
656656
fn insert_binary_expr(&mut self, left: &Expr, op: Operator, right: &Expr) -> Result<()> {
657657
match (left, op, right) {
658658
(Expr::Column(c), op, Expr::Literal(v, _)) => {
659+
// Normalize boolean expressions to canonical form:
660+
// col = false -> NOT col
661+
// col != true -> NOT col
662+
// col = true -> col
663+
// col != false -> col
664+
// This ensures semantic equivalence matching (e.g., "active = false" matches "NOT active")
665+
if let ScalarValue::Boolean(Some(b)) = v {
666+
match (op, b) {
667+
(Operator::Eq, false) | (Operator::NotEq, true) => {
668+
self.residuals
669+
.insert(Expr::Not(Box::new(Expr::Column(c.clone()))));
670+
return Ok(());
671+
}
672+
(Operator::Eq, true) | (Operator::NotEq, false) => {
673+
self.residuals.insert(Expr::Column(c.clone()));
674+
return Ok(());
675+
}
676+
_ => {}
677+
}
678+
}
679+
659680
if let Err(e) = self.add_range(c, &op, v) {
660681
// Add a range can fail in some cases, so just fallthrough
661682
log::debug!("failed to add range filter: {e}");
@@ -1443,4 +1464,191 @@ mod test {
14431464

14441465
Ok(())
14451466
}
1467+
1468+
#[tokio::test]
1469+
async fn test_boolean_expression_normalization() -> Result<()> {
1470+
let _ = env_logger::builder().is_test(true).try_init();
1471+
1472+
let ctx = SessionContext::new();
1473+
1474+
// Create table with boolean column
1475+
ctx.sql(
1476+
"CREATE TABLE bool_test (
1477+
id INT,
1478+
active BOOLEAN,
1479+
name VARCHAR
1480+
)",
1481+
)
1482+
.await?
1483+
.collect()
1484+
.await?;
1485+
1486+
ctx.sql("INSERT INTO bool_test VALUES (1, true, 'a'), (2, false, 'b')")
1487+
.await?
1488+
.collect()
1489+
.await?;
1490+
1491+
// MV: uses "active = false"
1492+
let mv_plan = ctx
1493+
.sql("SELECT * FROM bool_test WHERE active = false")
1494+
.await?
1495+
.into_optimized_plan()?;
1496+
let mv_normal_form = SpjNormalForm::new(&mv_plan)?;
1497+
1498+
ctx.sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false")
1499+
.await?
1500+
.collect()
1501+
.await?;
1502+
1503+
// Query: uses "NOT active" (semantically equivalent to "active = false")
1504+
let query_plan = ctx
1505+
.sql("SELECT id, name FROM bool_test WHERE NOT active")
1506+
.await?
1507+
.into_optimized_plan()?;
1508+
let query_normal_form = SpjNormalForm::new(&query_plan)?;
1509+
1510+
let table_ref = TableReference::bare("mv");
1511+
let rewritten = query_normal_form.rewrite_from(
1512+
&mv_normal_form,
1513+
table_ref.clone(),
1514+
provider_as_source(ctx.table_provider(table_ref).await?),
1515+
)?;
1516+
1517+
assert!(
1518+
rewritten.is_some(),
1519+
"Expected MV with 'active = false' to match query with 'NOT active'"
1520+
);
1521+
1522+
// Also test the reverse: MV with "NOT active", query with "active = false"
1523+
let mv_plan2 = ctx
1524+
.sql("SELECT * FROM bool_test WHERE NOT active")
1525+
.await?
1526+
.into_optimized_plan()?;
1527+
let mv_normal_form2 = SpjNormalForm::new(&mv_plan2)?;
1528+
1529+
ctx.sql("CREATE TABLE mv2 AS SELECT * FROM bool_test WHERE NOT active")
1530+
.await?
1531+
.collect()
1532+
.await?;
1533+
1534+
let query_plan2 = ctx
1535+
.sql("SELECT id FROM bool_test WHERE active = false")
1536+
.await?
1537+
.into_optimized_plan()?;
1538+
let query_normal_form2 = SpjNormalForm::new(&query_plan2)?;
1539+
1540+
let table_ref2 = TableReference::bare("mv2");
1541+
let rewritten2 = query_normal_form2.rewrite_from(
1542+
&mv_normal_form2,
1543+
table_ref2.clone(),
1544+
provider_as_source(ctx.table_provider(table_ref2).await?),
1545+
)?;
1546+
1547+
assert!(
1548+
rewritten2.is_some(),
1549+
"Expected MV with 'NOT active' to match query with 'active = false'"
1550+
);
1551+
1552+
Ok(())
1553+
}
1554+
1555+
#[tokio::test]
1556+
async fn test_boolean_column_normalization() -> Result<()> {
1557+
let _ = env_logger::builder().is_test(true).try_init();
1558+
1559+
let ctx = SessionContext::new();
1560+
1561+
ctx.sql(
1562+
"CREATE TABLE bool_test (
1563+
id INT,
1564+
active BOOLEAN,
1565+
name VARCHAR
1566+
)",
1567+
)
1568+
.await?
1569+
.collect()
1570+
.await?;
1571+
1572+
// Test: MV with "active = false" should match query with "NOT active"
1573+
let mv_plan = ctx
1574+
.sql("SELECT * FROM bool_test WHERE active = false")
1575+
.await?
1576+
.into_optimized_plan()?;
1577+
let mv_normal_form = SpjNormalForm::new(&mv_plan)?;
1578+
1579+
ctx.sql("CREATE TABLE mv AS SELECT * FROM bool_test WHERE active = false")
1580+
.await?
1581+
.collect()
1582+
.await?;
1583+
1584+
let query_plan = ctx
1585+
.sql("SELECT id, name FROM bool_test WHERE NOT active")
1586+
.await?
1587+
.into_optimized_plan()?;
1588+
let query_normal_form = SpjNormalForm::new(&query_plan)?;
1589+
1590+
let table_ref = TableReference::bare("mv");
1591+
let rewritten = query_normal_form.rewrite_from(
1592+
&mv_normal_form,
1593+
table_ref.clone(),
1594+
provider_as_source(ctx.table_provider(table_ref).await?),
1595+
)?;
1596+
1597+
// Should successfully rewrite
1598+
assert!(
1599+
rewritten.is_some(),
1600+
"Expected MV with 'active = false' to match query with 'NOT active'"
1601+
);
1602+
1603+
Ok(())
1604+
}
1605+
1606+
#[tokio::test]
1607+
async fn test_boolean_true_normalization() -> Result<()> {
1608+
let _ = env_logger::builder().is_test(true).try_init();
1609+
1610+
let ctx = SessionContext::new();
1611+
1612+
ctx.sql(
1613+
"CREATE TABLE bool_test2 (
1614+
id INT,
1615+
enabled BOOLEAN
1616+
)",
1617+
)
1618+
.await?
1619+
.collect()
1620+
.await?;
1621+
1622+
// Test: MV with "enabled = true" should match query with just "enabled"
1623+
let mv_plan = ctx
1624+
.sql("SELECT * FROM bool_test2 WHERE enabled = true")
1625+
.await?
1626+
.into_optimized_plan()?;
1627+
let mv_normal_form = SpjNormalForm::new(&mv_plan)?;
1628+
1629+
ctx.sql("CREATE TABLE mv2 AS SELECT * FROM bool_test2 WHERE enabled = true")
1630+
.await?
1631+
.collect()
1632+
.await?;
1633+
1634+
let query_plan = ctx
1635+
.sql("SELECT id FROM bool_test2 WHERE enabled")
1636+
.await?
1637+
.into_optimized_plan()?;
1638+
let query_normal_form = SpjNormalForm::new(&query_plan)?;
1639+
1640+
let table_ref = TableReference::bare("mv2");
1641+
let rewritten = query_normal_form.rewrite_from(
1642+
&mv_normal_form,
1643+
table_ref.clone(),
1644+
provider_as_source(ctx.table_provider(table_ref).await?),
1645+
)?;
1646+
1647+
assert!(
1648+
rewritten.is_some(),
1649+
"Expected MV with 'enabled = true' to match query with 'enabled'"
1650+
);
1651+
1652+
Ok(())
1653+
}
14461654
}

0 commit comments

Comments
 (0)