diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index ac326be9cb04..2bc1fde764dd 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -190,7 +190,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { /// default implementation will not be called (left as `todo!()`) fn simplify(&self) -> Option { let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { - Ok(Expr::WindowFunction(WindowFunction { + Ok(Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(avg_udaf()), args: window_function.args, partition_by: window_function.partition_by, diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 47dee391c751..935d1776a10c 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -564,19 +564,17 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let get_sort_keys = |expr: &Expr| match expr { - Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_function) => generate_sort_key( + window_function.partition_by(), + window_function.order_by(), + ), Expr::Alias(Alias { expr, .. }) => { // Convert &Box to &T match &**expr { - Expr::WindowFunction(WindowFunction { - ref partition_by, - ref order_by, - .. - }) => generate_sort_key(partition_by, order_by), + Expr::WindowFunction(window_function) => generate_sort_key( + window_function.partition_by(), + window_function.order_by(), + ), _ => unreachable!(), } } @@ -1503,14 +1501,15 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_function) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = window_function.as_ref(); let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 8155fd6a2ff9..7903a10b9f62 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -886,7 +886,7 @@ async fn window_using_aggregates() -> Result<()> { vec![col("c3")], ); - Expr::WindowFunction(w) + Expr::from(w) .null_treatment(NullTreatment::IgnoreNulls) .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) .window_frame(WindowFrame::new_bounds( @@ -2550,7 +2550,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { let df_results = ctx .table("t1") .await? - .select(vec![Expr::WindowFunction(WindowFunction::new( + .select(vec![Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 305519a1f4b4..8c0e27a5bf91 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -29,6 +29,7 @@ use crate::utils::expr_to_columns; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; +use crate::function::WindowFunctionSimplification; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ @@ -297,7 +298,7 @@ pub enum Expr { /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. - WindowFunction(WindowFunction), + WindowFunction(Box), // Boxed as it is large (>= 272 bytes) /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -341,6 +342,13 @@ impl From for Expr { } } +/// Create an [`Expr`] from a [`WindowFunction`] +impl From for Expr { + fn from(value: WindowFunction) -> Self { + Expr::WindowFunction(Box::new(value)) + } +} + /// Create an [`Expr`] from an optional qualifier and a [`FieldRef`]. This is /// useful for creating [`Expr`] from a [`DFSchema`]. /// @@ -774,6 +782,16 @@ impl WindowFunctionDefinition { WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), } } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + match self { + WindowFunctionDefinition::AggregateUDF(_) => None, + WindowFunctionDefinition::WindowUDF(udwf) => udwf.simplify(), + } + } } impl Display for WindowFunctionDefinition { @@ -838,6 +856,23 @@ impl WindowFunction { null_treatment: None, } } + + /// return the partition by expressions + pub fn partition_by(&self) -> &Vec { + &self.partition_by + } + + /// return the order by expressions + pub fn order_by(&self) -> &Vec { + &self.order_by + } + + /// Return the the inner window simplification function, if any + /// + /// See [`WindowFunctionSimplification`] for more information + pub fn simplify(&self) -> Option { + self.fun.simplify() + } } /// EXISTS expression @@ -1907,24 +1942,24 @@ impl NormalizeEq for Expr { _ => false, } } - ( - Expr::WindowFunction(WindowFunction { + (Expr::WindowFunction(left), Expr::WindowFunction(right)) => { + let WindowFunction { fun: self_fun, args: self_args, partition_by: self_partition_by, order_by: self_order_by, window_frame: self_window_frame, null_treatment: self_null_treatment, - }), - Expr::WindowFunction(WindowFunction { + } = left.as_ref(); + let WindowFunction { fun: other_fun, args: other_args, partition_by: other_partition_by, order_by: other_order_by, window_frame: other_window_frame, null_treatment: other_null_treatment, - }), - ) => { + } = right.as_ref(); + self_fun.name() == other_fun.name() && self_window_frame == other_window_frame && self_null_treatment == other_null_treatment @@ -2164,14 +2199,15 @@ impl HashNode for Expr { distinct.hash(state); null_treatment.hash(state); } - Expr::WindowFunction(WindowFunction { - fun, - args: _args, - partition_by: _partition_by, - order_by: _order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_func) => { + let WindowFunction { + fun, + args: _args, + partition_by: _partition_by, + order_by: _order_by, + window_frame, + null_treatment, + } = window_func.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); @@ -2472,14 +2508,15 @@ impl Display for SchemaDisplay<'_> { Ok(()) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_func) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = window_func.as_ref(); write!( f, "{}({})", @@ -2626,14 +2663,16 @@ impl Display for Expr { // Expr::ScalarFunction(ScalarFunction { func, args }) => { // write!(f, "{}", func.display_name(args).unwrap()) // } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_func) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = window_func.as_ref(); + fmt_function(f, &fun.to_string(), false, args, true)?; if let Some(nt) = null_treatment { @@ -3081,4 +3120,19 @@ mod test { rename: opt_rename, } } + + #[test] + fn test_size_of_expr() { + // because Expr is such a widely used struct in DataFusion + // it is important to keep its size as small as possible + // + // If this test fails when you change `Expr`, please try + // `Box`ing the fields to make `Expr` smaller + // See https://github.com/apache/datafusion/issues/14256 for details + assert_eq!(size_of::(), 112); + assert_eq!(size_of::(), 64); + assert_eq!(size_of::(), 24); // 3 ptrs + assert_eq!(size_of::>(), 24); + assert_eq!(size_of::>(), 8); + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a2de5e7b259f..edd42cb513ca 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -843,7 +843,7 @@ impl ExprFuncBuilder { udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); udwf.null_treatment = null_treatment; - Expr::WindowFunction(udwf) + Expr::from(udwf) } }; @@ -897,7 +897,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -937,7 +937,7 @@ impl ExprFunctionExt for Expr { ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))) } _ => ExprFuncBuilder::new(None), }; @@ -950,7 +950,7 @@ impl ExprFunctionExt for Expr { fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.partition_by = Some(partition_by); builder } @@ -961,7 +961,7 @@ impl ExprFunctionExt for Expr { fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(*udwf))); builder.window_frame = Some(window_frame); builder } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index e0235d32292f..8a19c9d04e42 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -78,6 +78,7 @@ pub type AggregateFunctionSimplification = Box< >; /// [crate::udwf::WindowUDFImpl::simplify] simplifier closure +/// /// A closure with two arguments: /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index daf1a1375eac..bb088491e087 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2420,19 +2420,24 @@ impl Window { .iter() .enumerate() .filter_map(|(idx, expr)| { - if let Expr::WindowFunction(WindowFunction { + let Expr::WindowFunction(window_func) = expr else { + return None; + }; + let WindowFunction { fun: WindowFunctionDefinition::WindowUDF(udwf), partition_by, .. - }) = expr - { - // When there is no PARTITION BY, row number will be unique - // across the entire table. - if udwf.name() == "row_number" && partition_by.is_empty() { - return Some(idx + input_len); - } + } = window_func.as_ref() + else { + return None; + }; + // When there is no PARTITION BY, row number will be unique + // across the entire table. + if udwf.name() == "row_number" && partition_by.is_empty() { + Some(idx + input_len) + } else { + None } - None }) .map(|idx| { FunctionalDependence::new(vec![idx], vec![], false) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index eacace5ed046..d2e59c2d9db5 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -89,12 +89,13 @@ impl TreeNode for Expr { (expr, when_then_expr, else_expr).apply_ref_elements(f), Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => (args, filter, order_by).apply_ref_elements(f), - Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { + Expr::WindowFunction(window_func) => { + let WindowFunction { + args, + partition_by, + order_by, + .. + } = window_func.as_ref(); (args, partition_by, order_by).apply_ref_elements(f) } Expr::InList(InList { expr, list, .. }) => { @@ -222,24 +223,28 @@ impl TreeNode for Expr { ))) })? } - Expr::WindowFunction(WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - null_treatment, - }) => (args, partition_by, order_by).map_elements(f)?.update_data( - |(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }, - ), + Expr::WindowFunction(window_func) => { + let WindowFunction { + args, + fun, + partition_by, + order_by, + window_frame, + null_treatment, + } = *window_func; + + (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }, + ) + } Expr::AggregateFunction(AggregateFunction { args, func, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 96929ffeb0ed..44e3a4749b50 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -133,7 +133,7 @@ impl WindowUDF { pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 049926fb0bcd..2fde551e47fb 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -588,7 +588,8 @@ pub fn group_window_expr_by_sort_keys( ) -> Result)>> { let mut result = vec![]; window_expr.into_iter().try_for_each(|expr| match &expr { - Expr::WindowFunction( WindowFunction{ partition_by, order_by, .. }) => { + Expr::WindowFunction( window_func) => { + let WindowFunction{ partition_by, order_by, .. } = window_func.as_ref(); let sort_key = generate_sort_key(partition_by, order_by)?; if let Some((_, values)) = result.iter_mut().find( |group: &&mut (WindowSortKey, Vec)| matches!(group, (key, _) if *key == sort_key), @@ -1439,19 +1440,19 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )); @@ -1469,25 +1470,25 @@ mod tests { let age_asc = Sort::new(col("age"), true, true); let name_desc = Sort::new(col("name"), false, true); let created_at_desc = Sort::new(col("created_at"), false, true); - let max1 = Expr::WindowFunction(WindowFunction::new( + let max1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let max2 = Expr::WindowFunction(WindowFunction::new( + let max2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], )); - let min3 = Expr::WindowFunction(WindowFunction::new( + let min3 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], )) .order_by(vec![age_asc.clone(), name_desc.clone()]) .build() .unwrap(); - let sum4 = Expr::WindowFunction(WindowFunction::new( + let sum4 = Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], )) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 95b6f9dc764f..15fd4fc9ebc2 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -219,7 +219,7 @@ mod tests { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(WindowFunction::new( + .window(vec![Expr::from(WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 85fc9b31bcdd..16025dfb039c 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -528,14 +528,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { ), ))) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_function) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + } = *window_function; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -551,7 +552,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }; Ok(Transformed::yes( - Expr::WindowFunction(WindowFunction::new(fun, args)) + Expr::from(WindowFunction::new(fun, args)) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b7dd391586a1..07b10108a977 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1941,7 +1941,7 @@ mod tests { fn test_window() -> Result<()> { let table_scan = test_table_scan()?; - let max1 = Expr::WindowFunction(expr::WindowFunction::new( + let max1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], )) @@ -1949,7 +1949,7 @@ mod tests { .build() .unwrap(); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( + let max2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], )); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7cb0e7c2f1f7..24feea270277 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1540,7 +1540,7 @@ mod tests { fn filter_move_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1568,7 +1568,7 @@ mod tests { fn filter_move_complex_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1595,7 +1595,7 @@ mod tests { fn filter_move_partial_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1624,7 +1624,7 @@ mod tests { fn filter_expression_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1654,7 +1654,7 @@ mod tests { fn filter_order_keep_window() -> Result<()> { let table_scan = test_table_scan()?; - let window = Expr::WindowFunction(WindowFunction::new( + let window = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1683,7 +1683,7 @@ mod tests { fn filter_multiple_windows_common_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1694,7 +1694,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1722,7 +1722,7 @@ mod tests { fn filter_multiple_windows_disjoint_partitions() -> Result<()> { let table_scan = test_table_scan()?; - let window1 = Expr::WindowFunction(WindowFunction::new( + let window1 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), @@ -1733,7 +1733,7 @@ mod tests { .build() .unwrap(); - let window2 = Expr::WindowFunction(WindowFunction::new( + let window2 = Expr::from(WindowFunction::new( WindowFunctionDefinition::WindowUDF( datafusion_functions_window::rank::rank_udwf(), ), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 29f3d7cbda39..a3b960916c23 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -35,11 +35,10 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarV use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, - WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ - expr::{InList, InSubquery, WindowFunction}, + expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -1431,15 +1430,14 @@ impl TreeNodeRewriter for Simplifier<'_, S> { (_, expr) => Transformed::no(expr), }, - Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::WindowUDF(ref udwf), - .. - }) => match (udwf.simplify(), expr) { - (Some(simplify_function), Expr::WindowFunction(wf)) => { - Transformed::yes(simplify_function(wf, info)?) + Expr::WindowFunction(ref window_function) => { + match (window_function.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf)) => { + Transformed::yes(simplify_function(*wf, info)?) + } + (_, expr) => Transformed::no(expr), } - (_, expr) => Transformed::no(expr), - }, + } // // Rules for Between @@ -1893,9 +1891,11 @@ fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { #[cfg(test)] mod tests { + use super::*; use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; + use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ function::{ AccumulatorArgs, AggregateFunctionSimplification, @@ -1912,8 +1912,6 @@ mod tests { sync::Arc, }; - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ @@ -4107,8 +4105,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -4116,8 +4113,7 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + let window_function_expr = Expr::from(WindowFunction::new(udwf, vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index e04a89a03dae..4fab6c5d492e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -300,7 +300,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) @@ -317,7 +317,7 @@ pub fn parse_expr( }; let args = parse_exprs(&expr.exprs, registry, codec)?; - Expr::WindowFunction(WindowFunction::new( + Expr::from(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6d1d4f30610c..d4595429011f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -298,15 +298,16 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { + Expr::WindowFunction(window_function) => { + let expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + ref window_frame, + // TODO: support null treatment in proto + null_treatment: _, + } = window_function.as_ref(); let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateUDF(aggr_udf) => { let mut buf = Vec::new(); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9cc7514a0d33..52e864b876a0 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2288,7 +2288,7 @@ fn roundtrip_window() { let ctx = SessionContext::new(); // 1. without window_frame - let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr1 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2299,7 +2299,7 @@ fn roundtrip_window() { .unwrap(); // 2. with default window_frame - let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr2 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2316,7 +2316,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr3 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], )) @@ -2333,7 +2333,7 @@ fn roundtrip_window() { WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), ); - let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr4 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], )) @@ -2383,7 +2383,7 @@ fn roundtrip_window() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr5 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) @@ -2464,7 +2464,7 @@ fn roundtrip_window() { let dummy_window_udf = WindowUDF::from(SimpleWindowUDF::new()); - let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( + let test_expr6 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) @@ -2474,7 +2474,7 @@ fn roundtrip_window() { .build() .unwrap(); - let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + let text_expr7 = Expr::from(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1cf3dcb289a6..877bac0dd942 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -315,7 +315,7 @@ impl SqlToRel<'_, S> { }; if let Ok(fun) = self.find_window_func(&name) { - return Expr::WindowFunction(expr::WindowFunction::new( + return Expr::from(expr::WindowFunction::new( fun, self.function_args_to_expr(args, schema, planner_context)?, )) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 72618c2b6ab4..8408e9813550 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -187,14 +187,16 @@ impl Unparser<'_> { } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::WindowFunction(window_func) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment: _, + } = window_func.as_ref(); + let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -1930,7 +1932,7 @@ mod tests { "count(*) FILTER (WHERE true)", ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), args: vec![col("col")], partition_by: vec![], @@ -1941,7 +1943,7 @@ mod tests { r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( - Expr::WindowFunction(WindowFunction { + Expr::from(WindowFunction { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], @@ -2789,7 +2791,7 @@ mod tests { let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); let mut window_func = WindowFunction::new(func, vec![]); window_func.order_by = vec![Sort::new(col("a"), true, true)]; - let expr = Expr::WindowFunction(window_func); + let expr = Expr::from(window_func); let ast = unparser.expr_to_sql(&expr)?; let actual = ast.to_string(); diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index a9838ee68c44..5f8f8b11fb0e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -240,10 +240,13 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr let all_partition_keys = window_exprs .iter() .map(|expr| match expr { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => Ok(partition_by), + Expr::WindowFunction(window_func) => { + let WindowFunction { partition_by, .. } = window_func.as_ref(); + Ok(partition_by) + } Expr::Alias(Alias { expr, .. }) => match expr.as_ref() { - Expr::WindowFunction(WindowFunction { partition_by, .. }) => { - Ok(partition_by) + Expr::WindowFunction(window_function) => { + Ok(window_function.partition_by()) } expr => exec_err!("Impossibly got non-window expr {expr:?}"), }, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 89112e3fe84e..5bc9bf8c0c8b 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -2221,7 +2221,7 @@ pub async fn from_window_function( window_frame.regularize_order_bys(&mut order_by)?; - Ok(Expr::WindowFunction(expr::WindowFunction { + Ok(Expr::from(expr::WindowFunction { fun, args: from_substrait_func_args(consumer, &window.arguments, input_schema).await?, partition_by: from_substrait_rex_vec(consumer, &window.partitions, input_schema)