From 39ce55ea199d38789337d508b40d39f563ad41b3 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Tue, 24 Jun 2025 10:06:58 +0200 Subject: [PATCH 01/15] ci: test CI --- ci.test | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 ci.test diff --git a/ci.test b/ci.test new file mode 100644 index 0000000000000..e69de29bb2d1d From 365aa9e7cb5c9d885aebb7d911cc1f181c99ec20 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 31 Aug 2025 22:08:12 +0800 Subject: [PATCH 02/15] chore(deps): bump tracing-subscriber from 0.3.19 to 0.3.20 (#17355) Bumps [tracing-subscriber](https://github.com/tokio-rs/tracing) from 0.3.19 to 0.3.20. - [Release notes](https://github.com/tokio-rs/tracing/releases) - [Commits](https://github.com/tokio-rs/tracing/compare/tracing-subscriber-0.3.19...tracing-subscriber-0.3.20) --- updated-dependencies: - dependency-name: tracing-subscriber dependency-version: 0.3.20 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ffdb8c6403c1..1012587c0c271 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4235,12 +4235,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -4434,12 +4433,6 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "owo-colors" version = "4.2.1" @@ -6675,9 +6668,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "nu-ansi-term", "sharded-slab", From 21a06898ad65de3f1a7a3a789688d369d67174ef Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Tue, 1 Jul 2025 16:43:37 +0200 Subject: [PATCH 03/15] fix: temporary fix to handle incorrect coalesce (inserted during EnforceDistribution) which later causes an error during EnforceSort (without our patch). The next DataFusion version 46 upgrade does the proper fix, which is to not insert the coalesce in the first place. test: recreating the iox plan: * demonstrate the insertion of coalesce after the use of column estimates, and the removal of the test scenario's forcing of rr repartitioning test: reproducer of SanityCheck failure after EnforceSorting removes the coalesce added in the EnforceDistribution fix: special case to not remove the needed coalesce --- .../enforce_distribution.rs | 332 +++++++++++++++++- .../physical_optimizer/enforce_sorting.rs | 99 +++++- .../tests/physical_optimizer/test_utils.rs | 7 + .../src/enforce_sorting/mod.rs | 6 +- datafusion/physical-optimizer/src/utils.rs | 6 + 5 files changed, 432 insertions(+), 18 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index fd847763124ab..2dce87de00ede 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -23,7 +23,7 @@ use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, parquet_exec_with_stats, repartition_exec, schema, sort_exec, sort_exec_with_preserve_partitioning, sort_merge_join_exec, - sort_preserving_merge_exec, union_exec, + sort_preserving_merge_exec, trim_plan_display, union_exec, }; use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; @@ -39,10 +39,12 @@ use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::{assert_contains, ScalarValue}; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr::{JoinType, Operator}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{AggregateUDF, JoinType, Operator}; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ @@ -51,6 +53,7 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; +use datafusion_physical_optimizer::sanity_checker::check_plan_sanity; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -66,7 +69,7 @@ use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + displayable, get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, }; @@ -162,8 +165,8 @@ impl ExecutionPlan for SortRequiredExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -237,7 +240,7 @@ fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc, alias_pairs: Vec<(String, String)>, ) -> Arc { @@ -251,6 +254,15 @@ fn projection_exec_with_alias( fn aggregate_exec_with_alias( input: Arc, alias_pairs: Vec<(String, String)>, +) -> Arc { + aggregate_exec_with_aggr_expr_and_alias(input, vec![], alias_pairs) +} + +#[expect(clippy::type_complexity)] +fn aggregate_exec_with_aggr_expr_and_alias( + input: Arc, + aggr_expr: Vec<(Arc, Vec>)>, + alias_pairs: Vec<(String, String)>, ) -> Arc { let schema = schema(); let mut group_by_expr: Vec<(Arc, String)> = vec![]; @@ -271,18 +283,31 @@ fn aggregate_exec_with_alias( .collect::>(); let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); + let aggr_expr = aggr_expr + .into_iter() + .map(|(udaf, exprs)| { + AggregateExprBuilder::new(udaf.clone(), exprs) + .alias(udaf.name()) + .schema(Arc::clone(&schema)) + .build() + .map(Arc::new) + .unwrap() + }) + .collect::>(); + let filter_exprs = std::iter::repeat_n(None, aggr_expr.len()).collect::>(); + Arc::new( AggregateExec::try_new( AggregateMode::FinalPartitioned, final_grouping, - vec![], - vec![], + aggr_expr.clone(), + filter_exprs.clone(), Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, - vec![], - vec![], + aggr_expr, + filter_exprs, input, schema.clone(), ) @@ -439,6 +464,12 @@ impl TestConfig { self } + /// Set batch size. + fn with_batch_size(mut self, batch_size: usize) -> Self { + self.config.execution.batch_size = batch_size; + self + } + /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, /// and return the result plan (for potentional subsequent runs). @@ -2027,6 +2058,285 @@ fn repartition_ignores_union() -> Result<()> { Ok(()) } +fn aggregate_over_union(input: Vec>) -> Arc { + let union = union_exec(input); + let plan = + aggregate_exec_with_alias(union, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +// Aggregate over a union, +// with current testing setup. +// +// It will repartiton twice for an aggregate over a union. +// * repartitions before the partial aggregate. +// * repartitions before the final aggregation. +#[test] +fn repartitions_twice_for_aggregate_after_union() -> Result<()> { + let plan = aggregate_over_union(vec![parquet_exec(); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // Updated plan (post optimization) will have added RepartitionExecs (btwn union and aggregation). + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default(); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +// Aggregate over a union, +// but make the test setup more realistic. +// +// It will repartiton once for an aggregate over a union. +// * repartitions btwn partial & final aggregations. +#[test] +fn repartitions_once_for_aggregate_after_union() -> Result<()> { + // use parquet exec with stats + let plan: Arc = + aggregate_over_union(vec![parquet_exec_with_stats(10000); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // This removes the forced round-robin repartitioning, + // by no longer hard-coding batch_size=1. + // + // Updated plan (post optimization) will have added only 1 RepartitionExec. + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union( + input: Vec>, +) -> Arc { + let union = union_exec(input); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_once_for_aggregate_after_union`], but adds a sort btwn +/// the union and the aggregate. This changes the outcome: +/// +/// * we no longer get a distribution error. +/// * but we still get repartitioning? +#[test] +fn repartitions_for_aggregate_after_sorted_union() -> Result<()> { + let plan = aggregate_over_sorted_union(vec![parquet_exec_with_stats(10000); 2]); + + // With the sort, there is no distribution error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It does not repartition on the first run + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // But does repartition on the second run. + let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_sorted_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union_projection( + input: Vec>, +) -> Arc { + let union = union_exec(input); + let union_projection = projection_exec_with_alias( + union, + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union_projection); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "ProjectionExec: expr=[a@0 as a, b@1 as value]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_for_aggregate_after_sorted_union`], but adds a projection +/// as well between the union and aggregate. This change the outcome: +/// +/// * we no longer get repartitioning, and instead get coalescing. +#[test] +fn coalesces_for_aggregate_after_sorted_union_projection() -> Result<()> { + let plan = + aggregate_over_sorted_union_projection(vec![parquet_exec_with_stats(10000); 2]); + + // Same as `repartitions_for_aggregate_after_sorted_union`. No error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It no longer does a repartition on the first run. + // Instead adds a SPM. + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // Then it removes the SPM, and inserts a coalesace on the second run. + let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + #[test] fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index e31a30cc0883c..8f3cc3191b174 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -17,15 +17,16 @@ use std::sync::Arc; +use crate::physical_optimizer::enforce_distribution::projection_exec_with_alias; use crate::physical_optimizer::test_utils::{ aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, - projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, - sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, RequirementsTestExec, + parquet_exec_with_stats, projection_exec, repartition_exec, schema, sort_exec, + sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, }; use arrow::compute::SortOptions; @@ -47,6 +48,9 @@ use datafusion_physical_expr_common::sort_expr::{ }; use datafusion_physical_expr::{Distribution, Partitioning}; use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -2292,6 +2296,93 @@ async fn test_commutativity() -> Result<()> { Ok(()) } +fn single_partition_aggregate( + input: Arc, + alias_pairs: Vec<(String, String)>, +) -> Arc { + let schema = schema(); + let group_by = alias_pairs + .iter() + .map(|(column, alias)| (col(column, &input.schema()).unwrap(), alias.to_string())) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(group_by); + + Arc::new( + AggregateExec::try_new( + AggregateMode::SinglePartitioned, + group_by, + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) +} + +#[tokio::test] +async fn test_preserve_needed_coalesce() -> Result<()> { + // Input to EnforceSorting, from our test case. + let plan = projection_exec_with_alias( + union_exec(vec![parquet_exec_with_stats(10000); 2]), + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let plan = Arc::new(CoalescePartitionsExec::new(plan)); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let plan: Arc = + single_partition_aggregate(plan, vec![("a".to_string(), "a1".to_string())]); + let plan = sort_exec(sort_key, plan); + + // Starting plan: as in our test case. + assert_eq!( + get_plan_string(&plan), + vec![ + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + let checker = SanityCheckPlan::new().optimize(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // EnforceSorting will remove the coalesce, and add an SPM further up (above the aggregate). + let optimizer = EnforceSorting::new(); + let optimized = optimizer.optimize(plan, &Default::default())?; + assert_eq!( + get_plan_string(&optimized), + vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + // Plan is valid. + let checker = SanityCheckPlan::new(); + let checker = checker.optimize(optimized, &Default::default()); + assert!(checker.is_ok()); + + Ok(()) +} + #[tokio::test] async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 7fb0f795f2944..c5b3e27438116 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -509,6 +509,13 @@ pub fn check_integrity(context: PlanContext) -> Result Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() +} + // construct a stream partition for test purposes #[derive(Debug)] pub struct TestStreamPartition { diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 8a71b28486a2a..dae0edcfb1716 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -48,8 +48,8 @@ use crate::enforce_sorting::sort_pushdown::{ }; use crate::output_requirements::OutputRequirementExec; use crate::utils::{ - add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, - is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, + add_sort_above, add_sort_above_with_check, is_aggregation, is_coalesce_partitions, + is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, }; use crate::PhysicalOptimizerRule; @@ -678,7 +678,7 @@ fn remove_bottleneck_in_subplan( ) -> Result { let plan = &requirements.plan; let children = &mut requirements.children; - if is_coalesce_partitions(&children[0].plan) { + if is_coalesce_partitions(&children[0].plan) && !is_aggregation(plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_child_node = children[0].children.swap_remove(0); while new_child_node.plan.output_partitioning() == plan.output_partitioning() diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 3655e555a7440..d3207d4880a70 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use datafusion_common::Result; use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -113,3 +114,8 @@ pub fn is_repartition(plan: &Arc) -> bool { pub fn is_limit(plan: &Arc) -> bool { plan.as_any().is::() || plan.as_any().is::() } + +/// Checks whether the given operator is a [`AggregateExec`]. +pub fn is_aggregation(plan: &Arc) -> bool { + plan.as_any().is::() +} From d89e3b0ceeab32c9d4d12948e15458f65118ba40 Mon Sep 17 00:00:00 2001 From: Christian van der Loo Date: Wed, 23 Jul 2025 10:11:53 -0400 Subject: [PATCH 04/15] fix(build-wasm): put `arrow-ipc/zstd` dep under `compression` feature flag (#16844) --- Cargo.toml | 1 - datafusion/core/Cargo.toml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 601d11f12dd81..bb28098104df8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,7 +99,6 @@ arrow-flight = { version = "55.2.0", features = [ ] } arrow-ipc = { version = "55.2.0", default-features = false, features = [ "lz4", - "zstd", ] } arrow-ord = { version = "55.2.0", default-features = false } arrow-schema = { version = "55.2.0", default-features = false } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c4455e271c84b..1a6a66923e552 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,6 +47,7 @@ compression = [ "bzip2", "flate2", "zstd", + "arrow-ipc/zstd", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] From 458ad08bcad6b01b70140b3be5e0007d07f4d2b8 Mon Sep 17 00:00:00 2001 From: Liam Bao Date: Wed, 6 Aug 2025 06:00:14 -0400 Subject: [PATCH 05/15] Support `centroids` config for `approx_percentile_cont_with_weight` (#17003) * Support centroids config for `approx_percentile_cont_with_weight` * Match two functions' signature * Update docs * Address comments and unify centroids config --- .../src/approx_percentile_cont.rs | 15 ++- .../src/approx_percentile_cont_with_weight.rs | 111 +++++++++++++----- .../tests/cases/roundtrip_logical_plan.rs | 13 +- .../sqllogictest/test_files/aggregate.slt | 10 ++ docs/source/user-guide/expressions.md | 42 +++---- .../user-guide/sql/aggregate_functions.md | 17 ++- 6 files changed, 151 insertions(+), 57 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 55c8c847ad0a4..863ee15d89ec4 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -77,8 +77,14 @@ pub fn approx_percentile_cont( #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | @@ -313,7 +319,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } if arg_types.len() == 3 && !arg_types[2].is_integer() { return plan_err!( - "approx_percentile_cont requires integer max_size input types" + "approx_percentile_cont requires integer centroids input types" ); } Ok(arg_types[0].clone()) @@ -360,6 +366,11 @@ impl ApproxPercentileAccumulator { } } + // public for approx_percentile_cont_with_weight + pub(crate) fn max_size(&self) -> usize { + self.digest.max_size() + } + // public for approx_percentile_cont_with_weight pub fn merge_digests(&mut self, digests: &[TDigest]) { let digests = digests.iter().chain(std::iter::once(&self.digest)); diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ab847e8388691..d30ea624cae90 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -25,32 +25,53 @@ use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, -}; -use datafusion_functions_aggregate_common::tdigest::{ - Centroid, TDigest, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, }; +use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -make_udaf_expr_and_func!( +create_func!( ApproxPercentileContWithWeight, - approx_percentile_cont_with_weight, - expression weight percentile, - "Computes the approximate percentile continuous with weight of a set of numbers", approx_percentile_cont_with_weight_udaf ); +/// Computes the approximate percentile continuous with weight of a set of numbers +pub fn approx_percentile_cont_with_weight( + order_by: Sort, + weight: Expr, + percentile: Expr, + centroids: Option, +) -> Expr { + let expr = order_by.expr.clone(); + + let args = if let Some(centroids) = centroids { + vec![expr, weight, percentile, centroids] + } else { + vec![expr, weight, percentile] + }; + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_with_weight_udaf(), + args, + false, + None, + vec![order_by], + None, + )) +} + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql > SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; +---------------------------------------------------------------------------------------------+ @@ -58,6 +79,12 @@ make_udaf_expr_and_func!( +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( @@ -67,6 +94,10 @@ make_udaf_expr_and_func!( argument( name = "percentile", description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ), + argument( + name = "centroids", + description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] pub struct ApproxPercentileContWithWeight { @@ -91,21 +122,26 @@ impl Default for ApproxPercentileContWithWeight { impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with weight and float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + ])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + int.clone(), + ])); + } + } Self { - signature: Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Immutable, - ), + signature: Signature::one_of(variants, Immutable), approx_percentile_cont: ApproxPercentileCont::new(), } } @@ -138,6 +174,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { if arg_types[2] != DataType::Float64 { return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); } + if arg_types.len() == 4 && !arg_types[3].is_integer() { + return plan_err!( + "approx_percentile_cont_with_weight requires integer centroids input types" + ); + } Ok(arg_types[0].clone()) } @@ -148,17 +189,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } - if acc_args.exprs.len() != 3 { + if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 { return plan_err!( - "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]" ); } let sub_args = AccumulatorArgs { - exprs: &[ - Arc::clone(&acc_args.exprs[0]), - Arc::clone(&acc_args.exprs[2]), - ], + exprs: if acc_args.exprs.len() == 4 { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + Arc::clone(&acc_args.exprs[3]), // centroids + ] + } else { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + ] + }, ..acc_args }; let approx_percentile_cont_accumulator = @@ -244,7 +293,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { let mut digests: Vec = vec![]; for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { digests.push(TDigest::new_with_centroid( - DEFAULT_MAX_SIZE, + self.approx_percentile_cont_accumulator.max_size(), Centroid::new(*mean, *weight), )) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6c51d553fe166..b56fdc0fede65 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -981,7 +981,18 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), - approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + None, + ), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + Some(lit(50)), + ), grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 753820b6b6193..122907d831816 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1790,6 +1790,16 @@ c 123 d 124 e 115 +# approx_percentile_cont_with_weight with centroids +query TI +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 74 +b 68 +c 123 +d 124 +e 115 + # csv_query_sum_crossjoin query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1 diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 03ab86eeb813a..abf0286fa85bd 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Functions -| Syntax | Description | -| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | -| avg(expr) | Сalculates the average value for `expr`. | -| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | -| approx_median(expr) | Calculates an approximation of the median for `expr`. | -| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. | -| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. | -| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | -| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | -| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | -| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | -| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | -| count(expr) | Returns the number of rows for `expr`. | -| count_distinct | Creates an expression to represent the count(distinct) aggregate function | -| cube(exprs) | Creates a grouping set for all combination of `exprs` | -| grouping_set(exprs) | Create a grouping set. | -| max(expr) | Finds the maximum value of `expr`. | -| median(expr) | Сalculates the median of `expr`. | -| min(expr) | Finds the minimum value of `expr`. | -| rollup(exprs) | Creates a grouping set for rollup sets. | -| sum(expr) | Сalculates the sum of `expr`. | +| Syntax | Description | +| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| avg(expr) | Сalculates the average value for `expr`. | +| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | +| approx_median(expr) | Calculates an approximation of the median for `expr`. | +| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | +| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | +| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | +| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | +| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | +| count(expr) | Returns the number of rows for `expr`. | +| count_distinct | Creates an expression to represent the count(distinct) aggregate function | +| cube(exprs) | Creates a grouping set for all combination of `exprs` | +| grouping_set(exprs) | Create a grouping set. | +| max(expr) | Finds the maximum value of `expr`. | +| median(expr) | Сalculates the median of `expr`. | +| min(expr) | Finds the minimum value of `expr`. | +| rollup(exprs) | Creates a grouping set for rollup sets. | +| sum(expr) | Сalculates the sum of `expr`. | ## Aggregate Function Builder diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 774a4fae6bf32..3c88714b5f101 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -834,7 +834,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -846,6 +846,12 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) #### Example ```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | @@ -859,7 +865,7 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) Returns the weighted approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -867,6 +873,7 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex - **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). +- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. #### Example @@ -877,4 +884,10 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ ``` From 3a7cb2fe64cad7f7d06d5ff43586d363a739d20f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 13 Aug 2025 05:44:44 -0700 Subject: [PATCH 06/15] (Re)Support old syntax for `approx_percentile_cont` and `approx_percentile_cont_with_weight` (#16999) * Add sqllogictests * Allow both new and old sytanx for approx_percentile_cont and approx_percentile_cont_with_weight * Update docs * Add documentation and more tests --- .../src/approx_percentile_cont.rs | 19 ++++++++- .../src/approx_percentile_cont_with_weight.rs | 10 +++++ datafusion/sql/src/expr/function.rs | 7 +--- .../sqllogictest/test_files/aggregate.slt | 41 ++++++++++++++++++- .../user-guide/sql/aggregate_functions.md | 29 +++++++++++++ 5 files changed, 99 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 863ee15d89ec4..fce300e79bea3 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -91,7 +91,24 @@ pub fn approx_percentile_cont( +-----------------------------------------------------------------------+ | 65.0 | +-----------------------------------------------------------------------+ -```"#, +``` +An alternate syntax is also supported: +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` +"#, standard_argument(name = "expression",), argument( name = "percentile", diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index d30ea624cae90..f70d751a8cb99 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -85,6 +85,16 @@ pub fn approx_percentile_cont_with_weight( +--------------------------------------------------------------------------------------------------+ | 78.5 | +--------------------------------------------------------------------------------------------------+ +``` +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e63ca75d019d0..74d40e145f1c7 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -380,10 +380,6 @@ impl SqlToRel<'_, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - if fm.is_ordered_set_aggregate() && within_group.is_empty() { - return plan_err!("WITHIN GROUP clause is required when calling ordered set aggregate function({})", fm.name()); - } - if null_treatment.is_some() && !fm.supports_null_handling_clause() { return plan_err!( "[IGNORE | RESPECT] NULLS are not permitted for {}", @@ -403,7 +399,8 @@ impl SqlToRel<'_, S> { None, )?; - // add target column expression in within group clause to function arguments + // Add the WITHIN GROUP ordering expressions to the front of the argument list + // So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg) if !within_group.is_empty() { args = within_group .iter() diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 122907d831816..4671408349e24 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1287,7 +1287,7 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 ## Column `c12` is omitted due to a large relative error (~10%) due to the small ## float values. -#csv_query_approx_percentile_cont (c2) +# csv_query_approx_percentile_cont (c2) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- @@ -1303,6 +1303,23 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS D ---- true + +# csv_query_approx_percentile_cont (c2, alternate syntax, should be the same as above) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + # csv_query_approx_percentile_cont (c3) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 @@ -1743,6 +1760,17 @@ c 122 d 124 e 115 + +# csv_query_approx_percentile_cont_with_weight (should be the same as above) +query TI +SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + query TI SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- @@ -1762,6 +1790,17 @@ c 122 d 124 e 115 +# csv_query_approx_percentile_cont_with_weight alternate syntax +query TI +SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + + query TI SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 3c88714b5f101..4f2f0abe55c9a 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -860,6 +860,24 @@ approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expressi +-----------------------------------------------------------------------+ ``` +An alternate syntax is also supported: + +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` + ### `approx_percentile_cont_with_weight` Returns the weighted approximate percentile of input values using the t-digest algorithm. @@ -891,3 +909,14 @@ approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROU | 78.5 | +--------------------------------------------------------------------------------------------------+ ``` + +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ +``` From 14c7724a6db298ef9c1854480a820cdc746c4f40 Mon Sep 17 00:00:00 2001 From: Qi Zhu <821684824@qq.com> Date: Mon, 28 Jul 2025 17:26:49 +0800 Subject: [PATCH 07/15] feat: support distinct for window (#16925) * feat: support distinct for window * fix * fix * fisx * fix unparse * fix test * fix test * easy way * add test * add comments --- datafusion-examples/examples/advanced_udwf.rs | 1 + datafusion/core/src/physical_planner.rs | 2 + .../core/tests/fuzz_cases/window_fuzz.rs | 3 + .../physical_optimizer/enforce_sorting.rs | 1 + .../tests/physical_optimizer/test_utils.rs | 1 + datafusion/expr/src/expr.rs | 27 ++- datafusion/expr/src/expr_fn.rs | 1 + datafusion/expr/src/planner.rs | 1 + datafusion/expr/src/tree_node.rs | 12 ++ datafusion/expr/src/udaf.rs | 42 +++-- datafusion/functions-aggregate/src/count.rs | 163 +++++++++++++++++- datafusion/functions-window/src/planner.rs | 15 +- .../optimizer/src/analyzer/type_coercion.rs | 29 +++- .../src/windows/bounded_window_agg_exec.rs | 1 + datafusion/physical-plan/src/windows/mod.rs | 42 +++-- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../proto/src/physical_plan/from_proto.rs | 1 + datafusion/sql/src/expr/function.rs | 12 ++ datafusion/sql/src/unparser/expr.rs | 13 +- datafusion/sqllogictest/test_files/window.slt | 79 +++++++++ .../consumer/expr/window_function.rs | 1 + .../producer/expr/window_function.rs | 1 + 22 files changed, 407 insertions(+), 42 deletions(-) diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index f7316ddc1bec0..e0fab7ee9f310 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -199,6 +199,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { order_by: window_function.params.order_by, window_frame: window_function.params.window_frame, null_treatment: window_function.params.null_treatment, + distinct: window_function.params.distinct, }, })) }; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ab123dcceadab..df24c19f78417 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1646,6 +1646,7 @@ pub fn create_window_expr_with_name( order_by, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); let physical_args = @@ -1674,6 +1675,7 @@ pub fn create_window_expr_with_name( window_frame, physical_schema, ignore_nulls, + *distinct, ) } other => plan_err!("Invalid window expression '{other:?}'"), diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 316d3ba5a926b..23e3281cf3861 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -288,6 +288,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { Arc::new(window_frame), &extended_schema, false, + false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -660,6 +661,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec1, false, @@ -678,6 +680,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec2, search_mode.clone(), diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 8f3cc3191b174..ef29d51e5d37e 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -3766,6 +3766,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { case.window_frame, input_schema.as_ref(), false, + false, )?; let window_exec = if window_expr.uses_bounded_memory() { Arc::new(BoundedWindowAggExec::try_new( diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index c5b3e27438116..5e2d61e68f8d7 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -265,6 +265,7 @@ pub fn bounded_window_exec_with_partition( Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, + false, ) .unwrap(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0749ff0e98b71..efe8a639087aa 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1131,6 +1131,8 @@ pub struct WindowFunctionParams { pub window_frame: WindowFrame, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, + /// Distinct flag + pub distinct: bool, } impl WindowFunction { @@ -1145,6 +1147,7 @@ impl WindowFunction { order_by: Vec::default(), window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, } } @@ -2291,6 +2294,7 @@ impl NormalizeEq for Expr { partition_by: self_partition_by, order_by: self_order_by, null_treatment: self_null_treatment, + distinct: self_distinct, }, } = left.as_ref(); let WindowFunction { @@ -2302,6 +2306,7 @@ impl NormalizeEq for Expr { partition_by: other_partition_by, order_by: other_order_by, null_treatment: other_null_treatment, + distinct: other_distinct, }, } = other.as_ref(); @@ -2325,6 +2330,7 @@ impl NormalizeEq for Expr { && a.nulls_first == b.nulls_first && a.expr.normalize_eq(&b.expr) }) + && self_distinct == other_distinct } ( Expr::Exists(Exists { @@ -2558,11 +2564,13 @@ impl HashNode for Expr { order_by: _, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); + distinct.hash(state); } Expr::InList(InList { expr: _expr, @@ -2865,15 +2873,27 @@ impl Display for SchemaDisplay<'_> { order_by, window_frame, null_treatment, + distinct, } = params; + // Write function name and open parenthesis + write!(f, "{fun}(")?; + + // If DISTINCT, emit the keyword + if *distinct { + write!(f, "DISTINCT ")?; + } + + // Write the comma‑separated argument list write!( f, - "{}({})", - fun, + "{}", schema_name_from_exprs_comma_separated_without_space(args)? )?; + // **Close the argument parenthesis** + write!(f, ")")?; + if let Some(null_treatment) = null_treatment { write!(f, " {null_treatment}")?; } @@ -3260,9 +3280,10 @@ impl Display for Expr { order_by, window_frame, null_treatment, + distinct, } = params; - fmt_function(f, &fun.to_string(), false, args, true)?; + fmt_function(f, &fun.to_string(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, "{nt}")?; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c0351a9dcaca9..fab86fe7663d5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -945,6 +945,7 @@ impl ExprFuncBuilder { window_frame: window_frame .unwrap_or_else(|| WindowFrame::new(has_order_by)), null_treatment, + distinct, }, }) } diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 067c7a94279fe..b04fe32d376e3 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -308,6 +308,7 @@ pub struct RawWindowExpr { pub order_by: Vec, pub window_frame: WindowFrame, pub null_treatment: Option, + pub distinct: bool, } /// Result of planning a raw expr with [`ExprPlanner`] diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f953aec5a1e39..b6f583ca4c746 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -242,10 +242,22 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { + if distinct { + return Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build() + .unwrap(); + } + Expr::from(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) .order_by(new_order_by) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b6c8eb627c775..15c0dd57ad2c6 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -554,14 +554,25 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut schema_name = String::new(); - schema_name.write_fmt(format_args!( - "{}({})", - self.name(), - schema_name_from_exprs(args)? - ))?; + + // Inject DISTINCT into the schema name when requested + if *distinct { + schema_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } else { + schema_name.write_fmt(format_args!( + "{}({})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } if let Some(null_treatment) = null_treatment { schema_name.write_fmt(format_args!(" {null_treatment}"))?; @@ -579,7 +590,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { " ORDER BY [{}]", schema_name_from_sorts(order_by)? ))?; - }; + } schema_name.write_fmt(format_args!(" {window_frame}"))?; @@ -648,15 +659,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut display_name = String::new(); - display_name.write_fmt(format_args!( - "{}({})", - self.name(), - expr_vec_fmt!(args) - ))?; + if *distinct { + display_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + expr_vec_fmt!(args) + ))?; + } else { + display_name.write_fmt(format_args!( + "{}({})", + self.name(), + expr_vec_fmt!(args) + ))?; + } if let Some(null_treatment) = null_treatment { display_name.write_fmt(format_args!(" {null_treatment}"))?; diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 09904bbad6ec5..7a7c2879aa790 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -31,7 +31,7 @@ use arrow::{ }; use datafusion_common::{ downcast_value, internal_err, not_impl_err, stats::Precision, - utils::expr::COUNT_STAR_EXPANSION, Result, ScalarValue, + utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue, }; use datafusion_expr::{ expr::WindowFunction, @@ -59,6 +59,7 @@ use std::{ ops::BitAnd, sync::Arc, }; + make_udaf_expr_and_func!( Count, count, @@ -406,6 +407,98 @@ impl AggregateUDFImpl for Count { // the same as new values are seen. SetMonotonicity::Increasing } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + if args.is_distinct { + let acc = + SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?; + Ok(Box::new(acc)) + } else { + let acc = CountAccumulator::new(); + Ok(Box::new(acc)) + } + } +} + +// DistinctCountAccumulator does not support retract_batch and sliding window +// this is a specialized accumulator for distinct count that supports retract_batch +// and sliding window. +#[derive(Debug)] +pub struct SlidingDistinctCountAccumulator { + counts: HashMap, + data_type: DataType, +} + +impl SlidingDistinctCountAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + counts: HashMap::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for SlidingDistinctCountAccumulator { + fn state(&mut self) -> Result> { + let keys = self.counts.keys().cloned().collect::>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + keys.as_slice(), + &self.data_type, + ))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + self.counts.remove(&v); + } + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let list_arr = states[0].as_list::(); + for inner in list_arr.iter().flatten() { + for j in 0..inner.len() { + let v = ScalarValue::try_from_array(&*inner, j)?; + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.counts.len() as i64))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) + } } #[derive(Debug)] @@ -878,4 +971,72 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn sliding_distinct_count_accumulator_basic() -> Result<()> { + // Basic update_batch + evaluate functionality + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // Create an Int32Array: [1, 2, 2, 3, null] + let values: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + ])); + acc.update_batch(&[values])?; + // Expect distinct values {1,2,3} → count = 3 + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_retract() -> Result<()> { + // Test that retract_batch properly decrements counts + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?; + // Initial batch: ["a", "b", "a"] + let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")])) + as ArrayRef; + acc.update_batch(&[arr1])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"} + + // Retract batch: ["a", null, "b"] + let arr2 = + Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef; + acc.retract_batch(&[arr2])?; + // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"} + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_merge_states() -> Result<()> { + // Test merging multiple accumulator states with merge_batch + let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // acc1 sees [1, 2] + acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?; + // acc2 sees [2, 3] + acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?; + // Extract their states as Vec + let state_sv1 = acc1.state()?; + let state_sv2 = acc2.state()?; + // Convert ScalarValue states into Vec, propagating errors + // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray + let state_arr1: Vec = state_sv1 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + let state_arr2: Vec = state_sv2 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + // Merge both states into a fresh accumulator + let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + merged.merge_batch(&state_arr1)?; + merged.merge_batch(&state_arr2)?; + // Expect distinct {1,2,3} → count = 3 + assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } } diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 091737bb9c156..5e3a6bc6336c3 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -41,6 +41,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; let origin_expr = Expr::from(WindowFunction { @@ -51,6 +52,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, }); @@ -68,6 +70,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let raw_expr = RawWindowExpr { @@ -77,6 +80,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }; // TODO: remove the next line after `Expr::Wildcard` is removed @@ -93,18 +97,23 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; - let new_expr = Expr::from(WindowFunction::new( + let mut new_expr_before_build = Expr::from(WindowFunction::new( func_def, vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?; + .null_treatment(null_treatment); + if distinct { + new_expr_before_build = new_expr_before_build.distinct(); + } + + let new_expr = new_expr_before_build.build()?; let new_expr = saved_name.restore(new_expr); return Ok(PlannerResult::Planned(new_expr)); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a98b0fdcc3d36..e6fc006cb2ff8 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -549,6 +549,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let window_frame = @@ -565,14 +566,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { _ => args, }; - Ok(Transformed::yes( - Expr::from(WindowFunction::new(fun, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?, - )) + if distinct { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build()?, + )) + } else { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) + } } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index d3335c0e7fe17..4c991544f877b 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1377,6 +1377,7 @@ mod tests { Arc::new(window_frame), &input.schema(), false, + false, )?], input, input_order_mode, diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5583abfd72a21..085b17cab9bc3 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -103,21 +103,38 @@ pub fn create_window_expr( window_frame: Arc, input_schema: &Schema, ignore_nulls: bool, + distinct: bool, ) -> Result> { Ok(match fun { WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) - .alias(name) - .with_ignore_nulls(ignore_nulls) - .build() - .map(Arc::new)?; - window_expr_from_aggregate_expr( - partition_by, - order_by, - window_frame, - aggregate, - ) + if distinct { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .distinct() + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } else { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } } WindowFunctionDefinition::WindowUDF(fun) => Arc::new(StandardWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, @@ -805,6 +822,7 @@ mod tests { Arc::new(WindowFrame::new(None)), schema.as_ref(), false, + false, )?], blocking_exec, false, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 43afaa0fbe655..f59e97df0d469 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -316,6 +316,7 @@ pub fn serialize_expr( ref window_frame, // TODO: support null treatment in proto null_treatment: _, + distinct: _, }, } = window_fun.as_ref(); let mut buf = Vec::new(); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 1c60470b2218f..2ed6ec037fc81 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -179,6 +179,7 @@ pub fn parse_physical_window_expr( Arc::new(window_frame), &extended_schema, false, + false, ) } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 74d40e145f1c7..fd0e7dc6e3b91 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -352,6 +352,7 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct: function_args.distinct, }; for planner in self.context_provider.get_expr_planners().iter() { @@ -368,8 +369,19 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct, } = window_expr; + if distinct { + return Expr::from(expr::WindowFunction::new(func_def, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build(); + } + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 4ddd5ccccbbd7..4c0dc316615c3 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,9 @@ use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, Expr as AstExpr, Function, Ident, Interval, - ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, ValueWithSpan, + self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, + Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, + ValueWithSpan, }; use std::sync::Arc; use std::vec; @@ -198,6 +199,7 @@ impl Unparser<'_> { partition_by, order_by, window_frame, + distinct, .. }, } = window_fun.as_ref(); @@ -256,7 +258,8 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, + duplicate_treatment: distinct + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -339,7 +342,7 @@ impl Unparser<'_> { }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { duplicate_treatment: distinct - .then_some(ast::DuplicateTreatment::Distinct), + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -2051,6 +2054,7 @@ mod tests { order_by: vec![], window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, }), r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, @@ -2076,6 +2080,7 @@ mod tests { ), ), null_treatment: None, + distinct: false, }, }), r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 82de11302857a..bed9121eec3fe 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5650,3 +5650,82 @@ WINDOW 3 7 4 11 5 16 + + +# window with distinct operation +statement ok +CREATE TABLE table_test_distinct_count ( + k VARCHAR, + v Int, + time TIMESTAMP WITH TIME ZONE +); + +statement ok +INSERT INTO table_test_distinct_count (k, v, time) VALUES + ('a', 1, '1970-01-01T00:01:00.00Z'), + ('a', 1, '1970-01-01T00:02:00.00Z'), + ('a', 1, '1970-01-01T00:03:00.00Z'), + ('a', 2, '1970-01-01T00:03:00.00Z'), + ('a', 1, '1970-01-01T00:04:00.00Z'), + ('b', 3, '1970-01-01T00:01:00.00Z'), + ('b', 3, '1970-01-01T00:02:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'); + +query TPII +SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:04:00Z 4 2 +b 1970-01-01T00:01:00Z 1 1 +b 1970-01-01T00:02:00Z 2 1 +b 1970-01-01T00:03:00Z 4 2 +b 1970-01-01T00:03:00Z 4 2 + + +query TT +EXPLAIN SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ODER BY k, time; +---- +logical_plan +01)Projection: oder.k, oder.time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count +02)--WindowAggr: windowExpr=[[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW]] +03)----SubqueryAlias: oder +04)------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)ProjectionExec: expr=[k@0 as k, time@2 as time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] +02)--BoundedWindowAggExec: wdw=[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW], mode=[Sorted] +03)----SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +04)------CoalesceBatchesExec: target_batch_size=1 +05)--------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 +06)----------DataSourceExec: partitions=2, partition_sizes=[5, 4] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 80b643a547ee6..27f0de84b7a08 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -112,6 +112,7 @@ pub async fn from_window_function( order_by, window_frame, null_treatment: None, + distinct: false, }, })) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs index 17e71f2d7c147..94a39e930f1c2 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -42,6 +42,7 @@ pub fn from_window_function( order_by, window_frame, null_treatment: _, + distinct: _, }, } = window_fn; // function reference From cb19c83db22898a460e837875668418dbdb1574f Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 5 Sep 2025 13:42:04 +0200 Subject: [PATCH 08/15] fix: return ALL constants in `EquivalenceProperties::constants` (#17404) * test: regression test for #17372 * test: add more direct regression for #17372 * fix: return ALL constants in `EquivalenceProperties::constants` --- .../physical_optimizer/sanity_checker.rs | 81 ++++++++++++++++++- .../src/equivalence/properties/mod.rs | 9 ++- .../src/equivalence/properties/union.rs | 59 ++++++++++++++ 3 files changed, 142 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/sanity_checker.rs b/datafusion/core/tests/physical_optimizer/sanity_checker.rs index 6233f5d09c56e..ce6eb13c86c44 100644 --- a/datafusion/core/tests/physical_optimizer/sanity_checker.rs +++ b/datafusion/core/tests/physical_optimizer/sanity_checker.rs @@ -20,7 +20,8 @@ use std::sync::Arc; use crate::physical_optimizer::test_utils::{ bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, - repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + projection_exec, repartition_exec, sort_exec, sort_expr, sort_expr_options, + sort_merge_join_exec, sort_preserving_merge_exec, union_exec, }; use arrow::compute::SortOptions; @@ -28,8 +29,8 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{JoinType, Result}; -use datafusion_physical_expr::expressions::col; +use datafusion_common::{JoinType, Result, ScalarValue}; +use datafusion_physical_expr::expressions::{col, Literal}; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; @@ -665,3 +666,77 @@ async fn test_sort_merge_join_dist_missing() -> Result<()> { assert_sanity_check(&smj, false); Ok(()) } + +/// A particular edge case. +/// +/// See . +#[tokio::test] +async fn test_union_with_sorts_and_constants() -> Result<()> { + let schema_in = create_test_schema2(); + + let proj_exprs_1 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + let proj_exprs_2 = vec![ + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("foo".to_owned())))) as _, + "const_1".to_owned(), + ), + ( + Arc::new(Literal::new(ScalarValue::Utf8(Some("bar".to_owned())))) as _, + "const_2".to_owned(), + ), + (col("a", &schema_in).unwrap(), "a".to_owned()), + ]; + + let source_1 = memory_exec(&schema_in); + let source_1 = projection_exec(proj_exprs_1.clone(), source_1).unwrap(); + let schema_sources = source_1.schema(); + let ordering_sources: LexOrdering = + [sort_expr("a", &schema_sources).nulls_last()].into(); + let source_1 = sort_exec(ordering_sources.clone(), source_1); + + let source_2 = memory_exec(&schema_in); + let source_2 = projection_exec(proj_exprs_2, source_2).unwrap(); + let source_2 = sort_exec(ordering_sources.clone(), source_2); + + let plan = union_exec(vec![source_1, source_2]); + + let schema_out = plan.schema(); + let ordering_out: LexOrdering = [ + sort_expr("const_1", &schema_out).nulls_last(), + sort_expr("const_2", &schema_out).nulls_last(), + sort_expr("a", &schema_out).nulls_last(), + ] + .into(); + + let plan = sort_preserving_merge_exec(ordering_out, plan); + + let plan_str = displayable(plan.as_ref()).indent(true).to_string(); + let plan_str = plan_str.trim(); + assert_snapshot!( + plan_str, + @r" + SortPreservingMergeExec: [const_1@0 ASC NULLS LAST, const_2@1 ASC NULLS LAST, a@2 ASC NULLS LAST] + UnionExec + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, foo as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + SortExec: expr=[a@2 ASC NULLS LAST], preserve_partitioning=[false] + ProjectionExec: expr=[foo as const_1, bar as const_2, a@0 as a] + DataSourceExec: partitions=1, partition_sizes=[0] + " + ); + + assert_sanity_check(&plan, true); + + Ok(()) +} diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 6d18d34ca4ded..fed3b78de8019 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -255,10 +255,11 @@ impl EquivalenceProperties { pub fn constants(&self) -> Vec { self.eq_group .iter() - .filter_map(|c| { - c.constant.as_ref().and_then(|across| { - c.canonical_expr() - .map(|expr| ConstExpr::new(Arc::clone(expr), across.clone())) + .flat_map(|c| { + c.iter().filter_map(|expr| { + c.constant + .as_ref() + .map(|across| ConstExpr::new(Arc::clone(expr), across.clone())) }) }) .collect() diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index 4f44b9b0c9d4a..efbefd0d39bfb 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -921,4 +921,63 @@ mod tests { .collect::>(), )) } + + #[test] + fn test_constants_share_values() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("const_1", DataType::Utf8, false), + Field::new("const_2", DataType::Utf8, false), + ])); + + let col_const_1 = col("const_1", &schema)?; + let col_const_2 = col("const_2", &schema)?; + + let literal_foo = ScalarValue::Utf8(Some("foo".to_owned())); + let literal_bar = ScalarValue::Utf8(Some("bar".to_owned())); + + let const_expr_1_foo = ConstExpr::new( + Arc::clone(&col_const_1), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_foo = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_foo.clone())), + ); + let const_expr_2_bar = ConstExpr::new( + Arc::clone(&col_const_2), + AcrossPartitions::Uniform(Some(literal_bar.clone())), + ); + + let mut input1 = EquivalenceProperties::new(Arc::clone(&schema)); + let mut input2 = EquivalenceProperties::new(Arc::clone(&schema)); + + // | Input | Const_1 | Const_2 | + // | ----- | ------- | ------- | + // | 1 | foo | foo | + // | 2 | foo | bar | + input1.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_foo.clone()])?; + input2.add_constants(vec![const_expr_1_foo.clone(), const_expr_2_bar.clone()])?; + + // Calculate union properties + let union_props = calculate_union(vec![input1, input2], schema)?; + + // This should result in: + // const_1 = Uniform("foo") + // const_2 = Heterogeneous + assert_eq!(union_props.constants().len(), 2); + let union_const_1 = &union_props.constants()[0]; + assert!(union_const_1.expr.eq(&col_const_1)); + assert_eq!( + union_const_1.across_partitions, + AcrossPartitions::Uniform(Some(literal_foo)), + ); + let union_const_2 = &union_props.constants()[1]; + assert!(union_const_2.expr.eq(&col_const_2)); + assert_eq!( + union_const_2.across_partitions, + AcrossPartitions::Heterogeneous, + ); + + Ok(()) + } } From 3248a6eb2f1f79960ae6ee2e962a3cb5443182e1 Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Sun, 7 Sep 2025 02:05:17 +1000 Subject: [PATCH 09/15] feat: Support binary data types for `SortMergeJoin` `on` clause (#17431) * feat: Support binary data types for `SortMergeJoin` `on` clause * Add sql level tests for merge join on binary keys --------- Co-authored-by: Andrew Lamb --- .../src/joins/sort_merge_join.rs | 155 +++++++++++++++++- .../test_files/sort_merge_join.slt | 58 ++++++- 2 files changed, 209 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 9a68322834866..3708ec4900a0b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2503,6 +2503,10 @@ fn compare_join_arrays( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2571,6 +2575,10 @@ fn is_join_arrays_equal( DataType::Utf8 => compare_value!(StringArray), DataType::Utf8View => compare_value!(StringViewArray), DataType::LargeUtf8 => compare_value!(LargeStringArray), + DataType::Binary => compare_value!(BinaryArray), + DataType::BinaryView => compare_value!(BinaryViewArray), + DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), + DataType::LargeBinary => compare_value!(LargeBinaryArray), DataType::Decimal128(..) => compare_value!(Decimal128Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), @@ -2600,7 +2608,8 @@ mod tests { use arrow::array::{ builder::{BooleanBuilder, UInt64Builder}, - BooleanArray, Date32Array, Date64Array, Int32Array, RecordBatch, UInt64Array, + BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Int32Array, RecordBatch, UInt64Array, }; use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; @@ -2694,6 +2703,56 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + fn build_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Binary, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(BinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_fixed_size_binary_table( + a: (&str, &Vec<&[u8]>), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::FixedSizeBinary(3), false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(FixedSizeBinaryArray::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + /// returns a table with 3 columns of i32 in memory pub fn build_table_i32_nullable( a: (&str, &Vec>), @@ -3932,6 +3991,100 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_binary() -> Result<()> { + let left = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_fixed_size_binary() -> Result<()> { + let left = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b1", &vec![5, 10, 15]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_fixed_size_binary_table( + ( + "a1", + &vec![ + &[0xc0, 0xff, 0xee], + &[0xde, 0xca, 0xde], + &[0xfa, 0xca, 0xde], + ], + ), + ("b2", &vec![105, 110, 115]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, Inner).await?; + + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "#); + Ok(()) + } + #[tokio::test] async fn join_left_sort_order() -> Result<()> { let left = build_table( diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index c17fe8dfc7e6f..ed463333217af 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -833,9 +833,61 @@ t2 as ( 11 14 12 15 -# return sql params back to default values statement ok -set datafusion.optimizer.prefer_hash_join = true; +set datafusion.execution.batch_size = 8192; + +###### +## Tests for Binary, LargeBinary, BinaryView, FixedSizeBinary join keys +###### statement ok -set datafusion.execution.batch_size = 8192; +create table t1(x varchar, id1 int) as values ('aa', 1), ('bb', 2), ('aa', 3), (null, 4), ('ee', 5); + +statement ok +create table t2(y varchar, id2 int) as values ('ee', 10), ('bb', 20), ('cc', 30), ('cc', 40), (null, 50); + +# Binary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'Binary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'Binary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# LargeBinary join keys +query ?I?I +with t1 as (select arrow_cast(x, 'LargeBinary') as x, id1 from t1), + t2 as (select arrow_cast(y, 'LargeBinary') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# BinaryView join keys +query ?I?I +with t1 as (select arrow_cast(x, 'BinaryView') as x, id1 from t1), + t2 as (select arrow_cast(y, 'BinaryView') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +# FixedSizeBinary join keys +query ?I?I +with t1 as (select arrow_cast(arrow_cast(x, 'Binary'), 'FixedSizeBinary(2)') as x, id1 from t1), + t2 as (select arrow_cast(arrow_cast(y, 'Binary'), 'FixedSizeBinary(2)') as y, id2 from t2) +select * from t1 join t2 on t1.x = t2.y order by id1, id2 +---- +6262 2 6262 20 +6565 5 6565 10 + +statement ok +drop table t1; + +statement ok +drop table t2; + +# return sql params back to default values +statement ok +set datafusion.optimizer.prefer_hash_join = true; From 2365ab068582b952611d7d6d20a739aca30b2543 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 24 Jun 2025 14:54:00 +0200 Subject: [PATCH 10/15] chore: skip order calculation / exponential planning --- .../src/equivalence/properties/union.rs | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index efbefd0d39bfb..8ec2464068efe 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -67,16 +67,43 @@ fn calculate_union_binary( }) .collect::>(); + // TEMP HACK WORKAROUND + // Revert code from https://github.com/apache/datafusion/pull/12562 + // Context: https://github.com/apache/datafusion/issues/13748 + // Context: https://github.com/influxdata/influxdb_iox/issues/13038 + // Next, calculate valid orderings for the union by searching for prefixes // in both sides. - let mut orderings = UnionEquivalentOrderingBuilder::new(); - orderings.add_satisfied_orderings(&lhs, &rhs)?; - orderings.add_satisfied_orderings(&rhs, &lhs)?; - let orderings = orderings.build(); + let mut orderings = vec![]; + for ordering in lhs.normalized_oeq_class().into_iter() { + let mut ordering: Vec = ordering.into(); + + // Progressively shorten the ordering to search for a satisfied prefix: + while !rhs.ordering_satisfy(ordering.clone())? { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } + for ordering in rhs.normalized_oeq_class().into_iter() { + let mut ordering: Vec = ordering.into(); + + // Progressively shorten the ordering to search for a satisfied prefix: + while !lhs.ordering_satisfy(ordering.clone())? { + ordering.pop(); + } + // There is a non-trivial satisfied prefix, add it as a valid ordering: + if !ordering.is_empty() { + orderings.push(ordering); + } + } let mut eq_properties = EquivalenceProperties::new(lhs.schema); eq_properties.add_constants(constants)?; eq_properties.add_orderings(orderings); + Ok(eq_properties) } @@ -122,6 +149,7 @@ struct UnionEquivalentOrderingBuilder { orderings: Vec, } +#[expect(unused)] impl UnionEquivalentOrderingBuilder { fn new() -> Self { Self { orderings: vec![] } @@ -504,6 +532,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_fill_gaps() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -579,6 +608,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); @@ -607,6 +637,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_gap_fill_symmetric() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) @@ -658,6 +689,7 @@ mod tests { } #[test] + #[ignore = "InfluxData patch: chore: skip order calculation / exponential planning"] fn test_union_equivalence_properties_constants_middle_desc() -> Result<()> { let schema = create_test_schema().unwrap(); UnionEquivalenceTest::new(&schema) From 7de4de5a5da7cb3db512fac9da9acdf7807deb81 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 16 Jul 2024 12:14:19 -0400 Subject: [PATCH 11/15] (New) Test + workaround for SanityCheck plan --- datafusion/physical-optimizer/src/sanity_checker.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index acc70d39f057b..3cc5319f9e108 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -32,6 +32,8 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; use crate::PhysicalOptimizerRule; @@ -135,6 +137,14 @@ pub fn check_plan_sanity( plan.required_input_ordering(), plan.required_input_distribution(), ) { + // TEMP HACK WORKAROUND https://github.com/apache/datafusion/issues/11492 + if child.as_any().downcast_ref::().is_some() { + continue; + } + if child.as_any().downcast_ref::().is_some() { + continue; + } + let child_eq_props = child.equivalence_properties(); if let Some(sort_req) = sort_req { let sort_req = sort_req.into_single(); From 4e7ad0d212e7718004a5cc05954456259ae77858 Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Tue, 21 Oct 2025 13:33:16 -0400 Subject: [PATCH 12/15] chore: re-enable physical schema check and log a warning instead of error --- datafusion/core/src/physical_planner.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index df24c19f78417..a05aa510f1be5 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -712,10 +712,15 @@ impl DefaultPhysicalPlanner { differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); } } - return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences - .iter() - .map(|s| format!("\n\t- {s}")) - .join("")); + + log::warn!("Physical input schema should be the same as the one converted from logical input schema, but did not match for logical plan:\n{}", input.display_indent()); + + //influx: temporarily remove error and only log so that we can find a + //reproducer in production + // return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences + // .iter() + // .map(|s| format!("\n\t- {s}")) + // .join("")); } let groups = self.create_grouping_physical_expr( From 17be87bd761770a8cf356090de170641aaf2609e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 26 Sep 2025 11:11:23 -0400 Subject: [PATCH 13/15] [branch-50] Backport change to avoid debug symbols in ci builds to 50.0.0 --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index bb28098104df8..fe6667b7a8b33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -197,6 +197,7 @@ rpath = false strip = false # Retain debug info for flamegraphs [profile.ci] +debug = false inherits = "dev" incremental = false From ee81b1cc652bde6c131973d091b178836692112d Mon Sep 17 00:00:00 2001 From: Denise Wiedl Date: Thu, 18 Sep 2025 23:07:21 +0300 Subject: [PATCH 14/15] Keep aggregate udaf schema names unique when missing an order-by * test: reproducer of bug * fix: make schema names unique for approx_percentile_cont * test: regression test is now resolved --- datafusion/expr/src/udaf.rs | 2 +- .../sqllogictest/test_files/aggregate.slt | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 15c0dd57ad2c6..a0d2b6a96a480 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -459,7 +459,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { // exclude the first function argument(= column) in ordered set aggregate function, // because it is duplicated with the WITHIN GROUP clause in schema name. - let args = if self.is_ordered_set_aggregate() { + let args = if self.is_ordered_set_aggregate() && !order_by.is_empty() { &args[1..] } else { &args[..] diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 4671408349e24..ab31a87b9e35f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1771,6 +1771,29 @@ c 122 d 124 e 115 + +# using approx_percentile_cont on 2 columns with same signature +query TII +SELECT c1, approx_percentile_cont(c2, 0.95) AS c2, approx_percentile_cont(c3, 0.95) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 5 73 +b 5 68 +c 5 122 +d 5 124 +e 5 115 + +# error is unique to this UDAF +query TRR +SELECT c1, avg(c2) AS c2, avg(c3) AS c3 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 2.857142857143 -18.333333333333 +b 3.263157894737 -5.842105263158 +c 2.666666666667 -1.333333333333 +d 2.444444444444 25.444444444444 +e 3 40.333333333333 + + + query TI SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- From ed4f0976cb4f8f66aebcc3fe1b08f9721ffc7b12 Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Wed, 17 Dec 2025 09:44:43 -0500 Subject: [PATCH 15/15] fix: wrap join operators with cooperative() for cancellation support --- datafusion/physical-plan/src/joins/cross_join.rs | 9 +++++---- datafusion/physical-plan/src/joins/hash_join.rs | 5 +++-- datafusion/physical-plan/src/joins/nested_loop_join.rs | 5 +++-- datafusion/physical-plan/src/joins/sort_merge_join.rs | 5 +++-- .../physical-plan/src/joins/symmetric_hash_join.rs | 9 +++++---- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index b8ea6330a1e2e..b7943f92a82bd 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -25,6 +25,7 @@ use super::utils::{ BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; +use crate::coop::cooperative; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ @@ -324,7 +325,7 @@ impl ExecutionPlan for CrossJoinExec { })?; if enforce_batch_size_in_joins { - Ok(Box::pin(CrossJoinStream { + Ok(Box::pin(cooperative(CrossJoinStream { schema: Arc::clone(&self.schema), left_fut, right: stream, @@ -333,9 +334,9 @@ impl ExecutionPlan for CrossJoinExec { state: CrossJoinStreamState::WaitBuildSide, left_data: RecordBatch::new_empty(self.left().schema()), batch_transformer: BatchSplitter::new(batch_size), - })) + }))) } else { - Ok(Box::pin(CrossJoinStream { + Ok(Box::pin(cooperative(CrossJoinStream { schema: Arc::clone(&self.schema), left_fut, right: stream, @@ -344,7 +345,7 @@ impl ExecutionPlan for CrossJoinExec { state: CrossJoinStreamState::WaitBuildSide, left_data: RecordBatch::new_empty(self.left().schema()), batch_transformer: NoopBatchTransformer::new(), - })) + }))) } } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 84ca7ce19f887..b0db0e0d15966 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -33,6 +33,7 @@ use super::{ PartitionMode, SharedBitmapBuilder, }; use super::{JoinOn, JoinOnRef}; +use crate::coop::cooperative; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64}; use crate::projection::{ @@ -880,7 +881,7 @@ impl ExecutionPlan for HashJoinExec { None => self.column_indices.clone(), }; - Ok(Box::pin(HashJoinStream { + Ok(Box::pin(cooperative(HashJoinStream { schema: self.schema(), on_right, filter: self.filter.clone(), @@ -895,7 +896,7 @@ impl ExecutionPlan for HashJoinExec { batch_size, hashes_buffer: vec![], right_side_ordered: self.right.output_ordering().is_some(), - })) + }))) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 5bb1673d4af26..9117df4c8ce4c 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -30,6 +30,7 @@ use super::utils::{ StatefulStreamResult, }; use crate::common::can_project; +use crate::coop::cooperative; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, @@ -530,7 +531,7 @@ impl ExecutionPlan for NestedLoopJoinExec { None => self.column_indices.clone(), }; - Ok(Box::pin(NestedLoopJoinStream { + Ok(Box::pin(cooperative(NestedLoopJoinStream { schema: self.schema(), filter: self.filter.clone(), join_type: self.join_type, @@ -544,7 +545,7 @@ impl ExecutionPlan for NestedLoopJoinExec { left_data: None, join_result_status: None, intermediate_batch_size: batch_size, - })) + }))) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 3708ec4900a0b..a615e76ceccd5 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -33,6 +33,7 @@ use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::coop::cooperative; use crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::expressions::PhysicalSortExpr; use crate::joins::utils::{ @@ -501,7 +502,7 @@ impl ExecutionPlan for SortMergeJoinExec { .register(context.memory_pool()); // create join stream - Ok(Box::pin(SortMergeJoinStream::try_new( + Ok(Box::pin(cooperative(SortMergeJoinStream::try_new( context.session_config().spill_compression(), Arc::clone(&self.schema), self.sort_options.clone(), @@ -516,7 +517,7 @@ impl ExecutionPlan for SortMergeJoinExec { SortMergeJoinMetrics::new(partition, &self.metrics), reservation, context.runtime_env(), - )?)) + )?))) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 9a8d4cbb66050..255569cbcd5c9 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -33,6 +33,7 @@ use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; +use crate::coop::cooperative; use crate::execution_plan::{boundedness_from_children, emission_type_from_children}; use crate::joins::hash_join::{equal_rows_arr, update_hash}; use crate::joins::stream_join_utils::{ @@ -533,7 +534,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { } if enforce_batch_size_in_joins { - Ok(Box::pin(SymmetricHashJoinStream { + Ok(Box::pin(cooperative(SymmetricHashJoinStream { left_stream, right_stream, schema: self.schema(), @@ -551,9 +552,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { state: SHJStreamState::PullRight, reservation, batch_transformer: BatchSplitter::new(batch_size), - })) + }))) } else { - Ok(Box::pin(SymmetricHashJoinStream { + Ok(Box::pin(cooperative(SymmetricHashJoinStream { left_stream, right_stream, schema: self.schema(), @@ -571,7 +572,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { state: SHJStreamState::PullRight, reservation, batch_transformer: NoopBatchTransformer::new(), - })) + }))) } }