Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 142 additions & 19 deletions src/rewrite/normal_form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,32 +247,50 @@ impl SpjNormalForm {
/// This is useful for rewriting queries to use materialized views.
pub fn rewrite_from(
&self,
mut other: &Self,
other: &Self,
qualifier: TableReference,
source: Arc<dyn TableSource>,
) -> Result<Option<LogicalPlan>> {
log::trace!("rewriting from {qualifier}");

// Cache columns() result to avoid repeated Vec allocation in the loop.
// DFSchema::columns() creates a new Vec on each call.
let output_columns = self.output_schema.columns();

let mut new_output_exprs = Vec::with_capacity(self.output_exprs.len());
// check that our output exprs are sub-expressions of the other one's output exprs
for (i, output_expr) in self.output_exprs.iter().enumerate() {
let new_output_expr = other
.predicate
.normalize_expr(output_expr.clone())
.rewrite(&mut other)?
.data;

// Check that all references to the original tables have been replaced.
// All remaining column expressions should be unqualified, which indicates
// that they refer to the output of the sub-plan (in this case the view)
if new_output_expr
.column_refs()
.iter()
.any(|c| c.relation.is_some())
{
return Ok(None);
}
// Fast path for simple Column expressions (most common case).
// This avoids the expensive normalize_expr transform for columns.
let new_output_expr = if let Expr::Column(col) = output_expr {
let normalized_col = other.predicate.normalize_column(col);
match other.find_output_column(&normalized_col) {
Some(rewritten) => rewritten,
None => return Ok(None), // Column not found, can't rewrite
}
} else {
// Slow path: complex expressions need full transform
let new_output_expr = other
.predicate
.normalize_expr(output_expr.clone())
.rewrite(&mut &*other)?
.data;

// Check that all references to the original tables have been replaced.
// All remaining column expressions should be unqualified, which indicates
// that they refer to the output of the sub-plan (in this case the view)
if new_output_expr
.column_refs()
.iter()
.any(|c| c.relation.is_some())
{
return Ok(None);
}
new_output_expr
};

let column = &self.output_schema.columns()[i];
// Use cached columns instead of calling .columns() on each iteration
let column = &output_columns[i];
new_output_exprs.push(
new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()),
);
Expand All @@ -299,7 +317,7 @@ impl SpjNormalForm {
.into_iter()
.chain(range_filters)
.chain(residual_filters)
.map(|expr| expr.rewrite(&mut other).unwrap().data)
.map(|expr| expr.rewrite(&mut &*other).unwrap().data)
.reduce(|a, b| a.and(b));

if all_filters
Expand All @@ -318,6 +336,20 @@ impl SpjNormalForm {

builder.project(new_output_exprs)?.build().map(Some)
}

/// Looks up `col` among this normal form's output expressions without
/// walking any expression tree.
///
/// Only output expressions that are plain `Expr::Column` values are
/// considered, and the first one equal to `col` wins.
///
/// Returns an unqualified column expression named after the
/// `output_schema` field at the matching position, or `None` when no
/// output expression is exactly that column.
#[inline]
fn find_output_column(&self, col: &Column) -> Option<Expr> {
    // Position of the first output expression that is exactly `col`.
    let idx = self.output_exprs.iter().position(|expr| match expr {
        Expr::Column(candidate) => candidate == col,
        _ => false,
    })?;

    // The replacement refers to the output field by name only (no qualifier).
    let field_name = self.output_schema.field(idx).name().clone();
    Some(Expr::Column(Column::new_unqualified(field_name)))
}
}

/// Stores information on filters from a Select-Project-Join plan.
Expand Down Expand Up @@ -431,6 +463,17 @@ impl Predicate {
.and_then(|&idx| self.eq_classes.get(idx))
}

/// Maps a single column to its normal representative using one
/// equivalence-class lookup, with no expression-tree traversal.
///
/// If `class_for_column` finds a class for `col`, the first column of that
/// class is returned; otherwise `col` is returned unchanged (cloned).
///
/// # Panics
///
/// Panics if the class found for `col` contains no columns.
#[inline]
fn normalize_column(&self, col: &Column) -> Column {
    match self.class_for_column(col) {
        // The first member of the class acts as its representative.
        Some(eq_class) => eq_class.columns.first().unwrap().clone(),
        // Columns outside every class are already in normal form.
        None => col.clone(),
    }
}

/// Add a new column equivalence
fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> {
match (
Expand Down Expand Up @@ -792,6 +835,11 @@ impl Predicate {
/// Rewrite all expressions in terms of their normal representatives
/// with respect to this predicate's equivalence classes.
fn normalize_expr(&self, e: Expr) -> Expr {
// Fast path: if it's a simple Column, avoid full transform traversal
if let Expr::Column(ref c) = e {
return Expr::Column(self.normalize_column(c));
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's a simple column, it should be fast even when it goes through the transform, IMO. I'm curious why this reduces the cost.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xudong963 for the review,

Good question! The overhead isn't from the transform logic itself, but from the transform() machinery setup - it creates closures, iterators, and Transformed wrapper objects even for leaf nodes.
For 41 columns × 5-7 MVs × every query, these small costs add up. The fast path is just a HashMap lookup + clone.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added these comments to the PR in the latest commit.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I believe most of the improvement actually comes from the other change, which caches the repeated columns() calls.


e.transform(&|e| {
let c = match e {
Expr::Column(c) => c,
Expand Down Expand Up @@ -1320,4 +1368,79 @@ mod test {

Ok(())
}

/// Smoke test: builds a normal form for a query whose filter equates two
/// columns and checks that every projected column is still reported.
#[tokio::test]
async fn test_normalize_column_fast_path() -> Result<()> {
    let ctx = SessionContext::new();

    ctx.sql("CREATE TABLE t (a INT, b INT, c INT)")
        .await?
        .collect()
        .await?;

    // The filter `a = b` is a column equivalence between `a` and `b`.
    let dataframe = ctx.sql("SELECT a, b, c FROM t WHERE a = b").await?;
    let optimized = dataframe.into_optimized_plan()?;

    let spj = SpjNormalForm::new(&optimized)?;

    // One output expression per projected column.
    let output_count = spj.output_exprs().len();
    assert_eq!(output_count, 3);

    Ok(())
}

/// Rewrites a three-column query over a ten-column table in terms of a
/// view-like table covering the same data, and checks that the rewritten
/// plan still exposes exactly the three requested columns.
#[tokio::test]
async fn test_rewrite_from_with_many_columns() -> Result<()> {
    let ctx = SessionContext::new();

    // A wide table, so the per-column loop in `rewrite_from` runs many times.
    let create_wide = ctx
        .sql(
            "CREATE TABLE wide_table (
c0 INT, c1 INT, c2 INT, c3 INT, c4 INT,
c5 INT, c6 INT, c7 INT, c8 INT, c9 INT
)",
        )
        .await?;
    create_wide.collect().await?;

    // Plan that plays the role of the materialized view definition.
    let view_plan = ctx
        .sql("SELECT * FROM wide_table WHERE c0 >= 0")
        .await?
        .into_optimized_plan()?;

    // Narrower, more selective query that should be answerable from the view.
    let narrow_plan = ctx
        .sql("SELECT c0, c1, c2 FROM wide_table WHERE c0 >= 10")
        .await?
        .into_optimized_plan()?;

    let view_form = SpjNormalForm::new(&view_plan)?;
    let query_form = SpjNormalForm::new(&narrow_plan)?;

    // Register the table that stands in for the materialized view.
    ctx.sql("CREATE TABLE mv AS SELECT * FROM wide_table WHERE c0 >= 0")
        .await?
        .collect()
        .await?;

    let mv_ref = TableReference::bare("mv");
    let mv_provider = ctx.table_provider(mv_ref.clone()).await?;
    let mv_source = provider_as_source(mv_provider);

    // The query must be rewritable on top of the view.
    let outcome = query_form.rewrite_from(&view_form, mv_ref, mv_source)?;
    assert!(outcome.is_some());

    let rewritten_plan = outcome.unwrap();
    let field_count = rewritten_plan.schema().fields().len();
    assert_eq!(field_count, 3);

    Ok(())
}
}
Loading