diff --git a/src/rewrite/normal_form.rs b/src/rewrite/normal_form.rs index 195f8da..70fd129 100644 --- a/src/rewrite/normal_form.rs +++ b/src/rewrite/normal_form.rs @@ -247,32 +247,50 @@ impl SpjNormalForm { /// This is useful for rewriting queries to use materialized views. pub fn rewrite_from( &self, - mut other: &Self, + other: &Self, qualifier: TableReference, source: Arc, ) -> Result> { log::trace!("rewriting from {qualifier}"); + + // Cache columns() result to avoid repeated Vec allocation in the loop. + // DFSchema::columns() creates a new Vec on each call. + let output_columns = self.output_schema.columns(); + let mut new_output_exprs = Vec::with_capacity(self.output_exprs.len()); // check that our output exprs are sub-expressions of the other one's output exprs for (i, output_expr) in self.output_exprs.iter().enumerate() { - let new_output_expr = other - .predicate - .normalize_expr(output_expr.clone()) - .rewrite(&mut other)? - .data; - - // Check that all references to the original tables have been replaced. - // All remaining column expressions should be unqualified, which indicates - // that they refer to the output of the sub-plan (in this case the view) - if new_output_expr - .column_refs() - .iter() - .any(|c| c.relation.is_some()) - { - return Ok(None); - } + // Fast path for simple Column expressions (most common case). + // This avoids the expensive normalize_expr transform for columns. + let new_output_expr = if let Expr::Column(col) = output_expr { + let normalized_col = other.predicate.normalize_column(col); + match other.find_output_column(&normalized_col) { + Some(rewritten) => rewritten, + None => return Ok(None), // Column not found, can't rewrite + } + } else { + // Slow path: complex expressions need full transform + let new_output_expr = other + .predicate + .normalize_expr(output_expr.clone()) + .rewrite(&mut &*other)? + .data; + + // Check that all references to the original tables have been replaced. + // All remaining column expressions should be unqualified, which indicates + // that they refer to the output of the sub-plan (in this case the view) + if new_output_expr + .column_refs() + .iter() + .any(|c| c.relation.is_some()) + { + return Ok(None); + } + new_output_expr + }; - let column = &self.output_schema.columns()[i]; + // Use cached columns instead of calling .columns() on each iteration + let column = &output_columns[i]; new_output_exprs.push( new_output_expr.alias_qualified(column.relation.clone(), column.name.clone()), ); @@ -299,7 +317,7 @@ impl SpjNormalForm { .into_iter() .chain(range_filters) .chain(residual_filters) - .map(|expr| expr.rewrite(&mut other).unwrap().data) + .map(|expr| expr.rewrite(&mut &*other).unwrap().data) .reduce(|a, b| a.and(b)); if all_filters @@ -318,6 +336,20 @@ impl SpjNormalForm { builder.project(new_output_exprs)?.build().map(Some) } + + /// Fast path: find a column in output_exprs and return rewritten expression. + /// This avoids full tree traversal for simple column lookups. + #[inline] + fn find_output_column(&self, col: &Column) -> Option { + self.output_exprs + .iter() + .position(|e| matches!(e, Expr::Column(c) if c == col)) + .map(|idx| { + Expr::Column(Column::new_unqualified( + self.output_schema.field(idx).name().clone(), + )) + }) + } } /// Stores information on filters from a Select-Project-Join plan. @@ -431,6 +463,17 @@ impl Predicate { .and_then(|&idx| self.eq_classes.get(idx)) } + /// Fast path: normalize a single Column without full tree traversal. + /// This is O(1) lookup instead of O(n) transform. + #[inline] + fn normalize_column(&self, col: &Column) -> Column { + if let Some(eq_class) = self.class_for_column(col) { + eq_class.columns.first().unwrap().clone() + } else { + col.clone() + } + } + /// Add a new column equivalence fn add_equivalence(&mut self, c1: &Column, c2: &Column) -> Result<()> { match ( @@ -792,6 +835,15 @@ impl Predicate { /// Rewrite all expressions in terms of their normal representatives /// with respect to this predicate's equivalence classes. fn normalize_expr(&self, e: Expr) -> Expr { + // Fast path: if it's a simple Column, avoid full transform traversal. + // Even though transform() handles Column efficiently, the machinery setup + // (closures, iterators, Transformed wrappers) has overhead that adds up + // when called thousands of times (e.g., 41 columns × 5-7 MVs × every query). + // Direct HashMap lookup + clone is significantly faster. + if let Expr::Column(ref c) = e { + return Expr::Column(self.normalize_column(c)); + } + e.transform(&|e| { let c = match e { Expr::Column(c) => c, @@ -1320,4 +1372,75 @@ mod test { Ok(()) } + + #[tokio::test] + async fn test_normalize_column_fast_path() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.sql("CREATE TABLE t (a INT, b INT, c INT)") + .await? + .collect() + .await?; + + // Query with column equivalence: a = b + let plan = ctx + .sql("SELECT a, b, c FROM t WHERE a = b") + .await? + .into_optimized_plan()?; + + let normal_form = SpjNormalForm::new(&plan)?; + + // Verify that columns are normalized correctly + // a and b should be in the same equivalence class + assert_eq!(normal_form.output_exprs().len(), 3); + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_from_with_many_columns() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a wide table to test the columns() caching optimization + ctx.sql( + "CREATE TABLE wide_table ( + c0 INT, c1 INT, c2 INT, c3 INT, c4 INT, + c5 INT, c6 INT, c7 INT, c8 INT, c9 INT + )", + ) + .await? + .collect() + .await?; + + let base_plan = ctx + .sql("SELECT * FROM wide_table WHERE c0 >= 0") + .await? + .into_optimized_plan()?; + + let query_plan = ctx + .sql("SELECT c0, c1, c2 FROM wide_table WHERE c0 >= 10") + .await? + .into_optimized_plan()?; + + let base_nf = SpjNormalForm::new(&base_plan)?; + let query_nf = SpjNormalForm::new(&query_plan)?; + + // Create MV table + ctx.sql("CREATE TABLE mv AS SELECT * FROM wide_table WHERE c0 >= 0") + .await? + .collect() + .await?; + + let table_ref = TableReference::bare("mv"); + let provider = ctx.table_provider(table_ref.clone()).await?; + + // Test that rewrite_from works correctly with cached columns + let result = query_nf.rewrite_from(&base_nf, table_ref, provider_as_source(provider))?; + + assert!(result.is_some()); + let rewritten = result.unwrap(); + assert_eq!(rewritten.schema().fields().len(), 3); + + Ok(()) + } }