Support WITHIN GROUP syntax to standardize certain existing aggregate functions #13511

Status: Open

Wants to merge 42 commits into base: main

Changes shown from 9 commits (42 commits total)

Commits
a9b901a
Add within group variable to aggregate function and arguments
Garamda Nov 21, 2024
0918000
Merge branch 'main' into support_within_group_for_existing_aggregate_…
Garamda Nov 21, 2024
070a96b
Support within group and disable null handling for ordered set aggreg…
Garamda Jan 21, 2025
3fd92fd
Refactored function to match updated signature
Garamda Jan 21, 2025
4082a78
Modify proto to support within group clause
Garamda Jan 21, 2025
c3be3c6
Modify physical planner and accumulator to support ordered set aggreg…
Garamda Jan 21, 2025
9fd05a3
Support session management for ordered set aggregate functions
Garamda Jan 23, 2025
8518a59
Align code, tests, and examples with changes to aggregate function logic
Garamda Jan 25, 2025
79669d9
Fix typo in existing comments
Garamda Jan 25, 2025
597f4d7
Enhance test
Garamda Jan 27, 2025
d3b483c
Merge branch 'main' into support_within_group_for_existing_aggregate_…
Garamda Jan 28, 2025
a827c9d
Fix bug : handle missing within_group when applying children tree node
Garamda Jan 30, 2025
23bdf70
Change the signature of approx_percentile_cont for consistency
Garamda Jan 30, 2025
97d96ca
Add missing within_group for expr display
Garamda Jan 30, 2025
1b61b5b
Handle edge case when over and within group clause are used together
Garamda Jan 31, 2025
d0fdde3
Apply clippy advice: avoids too many arguments
Garamda Feb 1, 2025
3c8bce3
Add new test cases using descending order
Garamda Feb 1, 2025
be99a35
Apply cargo fmt
Garamda Feb 1, 2025
f9aa1fc
Revert unintended submodule changes
Garamda Feb 5, 2025
d5f0b62
Apply prettier guidance
Garamda Feb 5, 2025
d7f2f59
Apply doc guidance by update_function_doc.sh
Garamda Feb 7, 2025
7ef2139
Merge branch 'main' into support_within_group_for_existing_aggregate_…
Garamda Feb 7, 2025
91565b3
Rollback WITHIN GROUP and related logic after converting it into expr
Garamda Feb 27, 2025
d482bff
Rollback ordered set aggregate functions from session to save same in…
Garamda Feb 27, 2025
005a27c
Convert within group to order by when converting sql to expr
Garamda Feb 28, 2025
1179bc4
Rollback within group from proto
Garamda Feb 28, 2025
e5fc1a4
Utilize within group as order by in functions-aggregate
Garamda Feb 28, 2025
cf4faad
Apply clippy
Garamda Feb 28, 2025
fc7d2bc
Merge branch 'main' into support_within_group_for_existing_aggregate_…
Garamda Feb 28, 2025
5469e39
Convert order by to within group
Garamda Feb 28, 2025
d96b667
Apply cargo fmt
Garamda Feb 28, 2025
293d33e
Remove plain line breaks
Garamda Feb 28, 2025
ecdb21b
Remove duplicated column arg in schema name
Garamda Mar 1, 2025
d65420e
Refactor boolean functions to just return primitive type
Garamda Mar 1, 2025
b6d426a
Make within group necessary in the signature of existing ordered set …
Garamda Mar 1, 2025
4b0c52f
Apply cargo fmt
Garamda Mar 1, 2025
36a732d
Support a single ordering expression in the signature
Garamda Mar 1, 2025
8d6db85
Apply cargo fmt
Garamda Mar 1, 2025
db0355a
Add dataframe function test cases to verify descending ordering
Garamda Mar 1, 2025
37b783e
Apply cargo fmt
Garamda Mar 1, 2025
124d8c5
Apply code reviews
Garamda Mar 5, 2025
3259c95
Update error msg in test as corresponding code changed
Garamda Mar 5, 2025
1 change: 1 addition & 0 deletions datafusion-examples/examples/simplify_udaf_expression.rs
@@ -100,6 +100,7 @@ impl AggregateUDFImpl for BetterAvgUdaf {
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
aggregate_function.within_group,
)))
};

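The one-line change above matters because the example UDAF's simplify rule reconstructs the aggregate expression field by field, and any field that is not copied is silently dropped. A minimal, std-only sketch of that failure mode (the `AggCall` struct below is a toy stand-in, not DataFusion's actual `AggregateFunction` type):

```rust
// Toy model of an aggregate call; field names mirror the PR, the type does not.
#[derive(Clone, Debug, PartialEq)]
struct AggCall {
    args: Vec<String>,
    order_by: Option<Vec<String>>,
    within_group: Option<Vec<String>>,
}

// Rebuilding the expression must copy every field: before this PR's change,
// `within_group` was not carried over, so the clause was lost on simplify.
fn rebuild(call: &AggCall) -> AggCall {
    AggCall {
        args: call.args.clone(),
        order_by: call.order_by.clone(),
        within_group: call.within_group.clone(),
    }
}

fn main() {
    let call = AggCall {
        args: vec!["b".into()],
        order_by: None,
        within_group: Some(vec!["b ASC".into()]),
    };
    // Round-trips: the WITHIN GROUP ordering survives reconstruction.
    assert_eq!(rebuild(&call), call);
    println!("ok");
}
```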
4 changes: 4 additions & 0 deletions datafusion-examples/examples/sql_frontend.rs
@@ -157,6 +157,10 @@ impl ContextProvider for MyContextProvider {
None
}

fn get_ordered_set_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
None
}

fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
None
}
6 changes: 6 additions & 0 deletions datafusion/catalog/src/session.rs
@@ -110,6 +110,11 @@ pub trait Session: Send + Sync {
/// Return reference to aggregate_functions
fn aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>>;

/// Return reference to ordered_set_aggregate_functions
///
/// Note: ordered_set_aggregate_functions are a subset of aggregate_functions.
fn ordered_set_aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>>;

/// Return reference to window functions
fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>>;

@@ -132,6 +137,7 @@ impl From<&dyn Session> for TaskContext {
state.config().clone(),
state.scalar_functions().clone(),
state.aggregate_functions().clone(),
state.ordered_set_aggregate_functions().clone(),
state.window_functions().clone(),
state.runtime_env().clone(),
)
13 changes: 13 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
@@ -1580,6 +1580,10 @@ impl FunctionRegistry for SessionContext {
self.state.read().udaf(name)
}

fn ordered_set_udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
self.state.read().ordered_set_udaf(name)
}

fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}
@@ -1595,6 +1599,15 @@
self.state.write().register_udaf(udaf)
}

fn register_ordered_set_udaf(
&mut self,
ordered_set_udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
self.state
.write()
.register_ordered_set_udaf(ordered_set_udaf)
}

fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
self.state.write().register_udwf(udwf)
}
94 changes: 94 additions & 0 deletions datafusion/core/src/execution/session_state.rs
@@ -144,6 +144,10 @@ pub struct SessionState {
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Aggregate functions registered in the context
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Ordered-set aggregate functions registered in the context
///
/// Note: ordered_set_aggregate_functions are a subset of aggregate_functions.
ordered_set_aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
/// Window functions registered in the context
window_functions: HashMap<String, Arc<WindowUDF>>,
/// Deserializer registry for extensions.
@@ -202,6 +206,7 @@ impl Debug for SessionState {
.field("table_functions", &self.table_functions)
.field("scalar_functions", &self.scalar_functions)
.field("aggregate_functions", &self.aggregate_functions)
.field("ordered_set_aggregate_functions", &self.ordered_set_aggregate_functions)
.field("window_functions", &self.window_functions)
.field("prepared_plans", &self.prepared_plans)
.finish()
@@ -241,6 +246,10 @@ impl Session for SessionState {
&self.aggregate_functions
}

fn ordered_set_aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
&self.ordered_set_aggregate_functions
}

fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
&self.window_functions
}
@@ -880,6 +889,13 @@ impl SessionState {
&self.aggregate_functions
}

/// Return reference to ordered_set_aggregate_functions
///
/// Note: ordered_set_aggregate_functions are a subset of aggregate_functions.
pub fn ordered_set_aggregate_functions(&self) -> &HashMap<String, Arc<AggregateUDF>> {
&self.ordered_set_aggregate_functions
}

/// Return reference to window functions
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
&self.window_functions
@@ -968,6 +984,7 @@ pub struct SessionStateBuilder {
table_functions: Option<HashMap<String, Arc<TableFunction>>>,
scalar_functions: Option<Vec<Arc<ScalarUDF>>>,
aggregate_functions: Option<Vec<Arc<AggregateUDF>>>,
ordered_set_aggregate_functions: Option<Vec<Arc<AggregateUDF>>>,
window_functions: Option<Vec<Arc<WindowUDF>>>,
serializer_registry: Option<Arc<dyn SerializerRegistry>>,
file_formats: Option<Vec<Arc<dyn FileFormatFactory>>>,
@@ -998,6 +1015,7 @@ impl SessionStateBuilder {
table_functions: None,
scalar_functions: None,
aggregate_functions: None,
ordered_set_aggregate_functions: None,
window_functions: None,
serializer_registry: None,
file_formats: None,
@@ -1048,6 +1066,9 @@
aggregate_functions: Some(
existing.aggregate_functions.into_values().collect_vec(),
),
ordered_set_aggregate_functions: Some(
existing.ordered_set_aggregate_functions.into_values().collect_vec(),
),
window_functions: Some(existing.window_functions.into_values().collect_vec()),
serializer_registry: Some(existing.serializer_registry),
file_formats: Some(existing.file_formats.into_values().collect_vec()),
@@ -1073,6 +1094,7 @@
.with_expr_planners(SessionStateDefaults::default_expr_planners())
.with_scalar_functions(SessionStateDefaults::default_scalar_functions())
.with_aggregate_functions(SessionStateDefaults::default_aggregate_functions())
.with_ordered_set_aggregate_functions(SessionStateDefaults::default_ordered_set_aggregate_functions())
.with_window_functions(SessionStateDefaults::default_window_functions())
}

@@ -1206,6 +1228,17 @@
self
}

/// Set the map of ordered-set [`AggregateUDF`]s
///
/// Note: ordered_set_aggregate_functions are a subset of aggregate_functions.
pub fn with_ordered_set_aggregate_functions(
mut self,
ordered_set_aggregate_functions: Vec<Arc<AggregateUDF>>,
) -> Self {
self.ordered_set_aggregate_functions = Some(ordered_set_aggregate_functions);
self
}

/// Set the map of [`WindowUDF`]s
pub fn with_window_functions(
mut self,
@@ -1340,6 +1373,7 @@
table_functions,
scalar_functions,
aggregate_functions,
ordered_set_aggregate_functions,
window_functions,
serializer_registry,
file_formats,
@@ -1371,6 +1405,7 @@
table_functions: table_functions.unwrap_or_default(),
scalar_functions: HashMap::new(),
aggregate_functions: HashMap::new(),
ordered_set_aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
serializer_registry: serializer_registry
.unwrap_or(Arc::new(EmptySerializerRegistry)),
@@ -1411,6 +1446,20 @@
});
}

if let Some(ordered_set_aggregate_functions) = ordered_set_aggregate_functions {
ordered_set_aggregate_functions
.into_iter()
.for_each(|udaf| {
let existing_ordered_set_udf = state.register_ordered_set_udaf(udaf);
if let Ok(Some(existing_ordered_set_udf)) = existing_ordered_set_udf {
debug!(
"Overwrote an existing ordered-set UDAF: {}",
existing_ordered_set_udf.name()
);
}
});
}

if let Some(window_functions) = window_functions {
window_functions.into_iter().for_each(|udwf| {
let existing_udf = state.register_udwf(udwf);
@@ -1610,6 +1659,7 @@ impl Debug for SessionStateBuilder {
.field("table_functions", &self.table_functions)
.field("scalar_functions", &self.scalar_functions)
.field("aggregate_functions", &self.aggregate_functions)
.field("ordered_set_aggregate_functions", &self.ordered_set_aggregate_functions)
.field("window_functions", &self.window_functions)
.finish()
}
@@ -1696,6 +1746,13 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
self.state.aggregate_functions().get(name).cloned()
}

fn get_ordered_set_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
self.state
.ordered_set_aggregate_functions()
.get(name)
.cloned()
}

fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
self.state.window_functions().get(name).cloned()
}
@@ -1766,6 +1823,17 @@ impl FunctionRegistry for SessionState {
})
}

fn ordered_set_udaf(
&self,
name: &str,
) -> datafusion_common::Result<Arc<AggregateUDF>> {
let result = self.ordered_set_aggregate_functions.get(name);

result.cloned().ok_or_else(|| {
plan_datafusion_err!("There is no ordered set UDAF named \"{name}\" in the registry")
})
}

fn udwf(&self, name: &str) -> datafusion_common::Result<Arc<WindowUDF>> {
let result = self.window_functions.get(name);

@@ -1796,6 +1864,18 @@
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
}

fn register_ordered_set_udaf(
&mut self,
ordered_set_udaf: Arc<AggregateUDF>,
) -> datafusion_common::Result<Option<Arc<AggregateUDF>>> {
ordered_set_udaf.aliases().iter().for_each(|alias| {
self.ordered_set_aggregate_functions
.insert(alias.clone(), Arc::clone(&ordered_set_udaf));
});
Ok(self.ordered_set_aggregate_functions
.insert(ordered_set_udaf.name().into(), ordered_set_udaf))
}

fn register_udwf(
&mut self,
udwf: Arc<WindowUDF>,
@@ -1833,6 +1913,19 @@
Ok(udaf)
}

fn deregister_ordered_set_udaf(
&mut self,
name: &str,
) -> datafusion_common::Result<Option<Arc<AggregateUDF>>> {
let ordered_set_udaf = self.ordered_set_aggregate_functions.remove(name);
if let Some(ordered_set_udaf) = &ordered_set_udaf {
for alias in ordered_set_udaf.aliases() {
self.ordered_set_aggregate_functions.remove(alias);
}
}
Ok(ordered_set_udaf)
}

fn deregister_udwf(
&mut self,
name: &str,
@@ -1895,6 +1988,7 @@ impl From<&SessionState> for TaskContext {
state.config.clone(),
state.scalar_functions.clone(),
state.aggregate_functions.clone(),
state.ordered_set_aggregate_functions.clone(),
state.window_functions.clone(),
Arc::clone(&state.runtime_env),
)
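The registration and deregistration code above maintains one invariant: an ordered-set UDAF is stored under its canonical name and under every alias, and removing it must clear all of those entries. A simplified, std-only sketch of that pattern (the `Udaf` struct and the alias string are illustrative stand-ins for `AggregateUDF` and its real aliases):

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Stand-in for AggregateUDF: a canonical name plus aliases.
struct Udaf {
    name: String,
    aliases: Vec<String>,
}

// Insert under every alias, then under the canonical name. The return value
// is the previously registered function for the canonical name, if any,
// mirroring HashMap::insert semantics.
fn register(
    registry: &mut HashMap<String, Arc<Udaf>>,
    udaf: Arc<Udaf>,
) -> Option<Arc<Udaf>> {
    for alias in &udaf.aliases {
        registry.insert(alias.clone(), Arc::clone(&udaf));
    }
    registry.insert(udaf.name.clone(), udaf)
}

// Remove the canonical entry and, if it existed, every alias entry too.
fn deregister(
    registry: &mut HashMap<String, Arc<Udaf>>,
    name: &str,
) -> Option<Arc<Udaf>> {
    let removed = registry.remove(name);
    if let Some(udaf) = &removed {
        for alias in &udaf.aliases {
            registry.remove(alias);
        }
    }
    removed
}

fn main() {
    let mut registry = HashMap::new();
    let udaf = Arc::new(Udaf {
        name: "approx_percentile_cont".into(),
        aliases: vec!["percentile_cont_approx".into()], // hypothetical alias
    });
    assert!(register(&mut registry, udaf).is_none());
    assert_eq!(registry.len(), 2); // canonical name + alias
    assert!(deregister(&mut registry, "approx_percentile_cont").is_some());
    assert!(registry.is_empty()); // alias entry removed as well
    println!("ok");
}
```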
7 changes: 7 additions & 0 deletions datafusion/core/src/execution/session_state_defaults.rs
@@ -114,6 +114,13 @@ impl SessionStateDefaults {
functions_aggregate::all_default_aggregate_functions()
}

/// returns the list of default ordered-set [`AggregateUDF`]s
///
/// Note: default_ordered_set_aggregate_functions are a subset of default_aggregate_functions.
pub fn default_ordered_set_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
functions_aggregate::all_default_ordered_set_aggregate_functions()
}

/// returns the list of default [`WindowUDF`]s
pub fn default_window_functions() -> Vec<Arc<WindowUDF>> {
functions_window::all_default_window_functions()
16 changes: 12 additions & 4 deletions datafusion/core/src/physical_planner.rs
@@ -1594,6 +1594,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
filter,
order_by,
null_treatment,
within_group,
}) => {
let name = if let Some(name) = name {
name
@@ -1616,13 +1617,20 @@
== NullTreatment::IgnoreNulls;

let (agg_expr, filter, order_by) = {
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
let physical_sort_exprs = match within_group {
Some(within_group) => Some(create_physical_sort_exprs(
within_group,
logical_input_schema,
execution_props,
)?),
None => None,
None => match order_by {
Some(order_by) => Some(create_physical_sort_exprs(
order_by,
logical_input_schema,
execution_props,
)?),
None => None,
},
};

let ordering_reqs: LexOrdering =
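The nested match above encodes a simple precedence rule: a WITHIN GROUP ordering, when present, wins; otherwise any ORDER BY ordering is used; otherwise no sort is required. With sort expressions reduced to strings, the rule collapses to `Option::or` (a sketch of the precedence only, not the planner's actual types):

```rust
// Precedence between the two ordering sources, as in the planner change:
// WITHIN GROUP first, then ORDER BY, then none.
fn effective_ordering(
    within_group: Option<Vec<&'static str>>,
    order_by: Option<Vec<&'static str>>,
) -> Option<Vec<&'static str>> {
    within_group.or(order_by)
}

fn main() {
    // WITHIN GROUP wins when both are present.
    assert_eq!(
        effective_ordering(Some(vec!["b ASC"]), Some(vec!["a DESC"])),
        Some(vec!["b ASC"])
    );
    // Fall back to ORDER BY when there is no WITHIN GROUP clause.
    assert_eq!(
        effective_ordering(None, Some(vec!["a DESC"])),
        Some(vec!["a DESC"])
    );
    // No ordering at all.
    assert_eq!(effective_ordering(None, None), None);
    println!("ok");
}
```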
36 changes: 18 additions & 18 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -360,14 +360,14 @@ async fn test_fn_approx_median() -> Result<()> {

#[tokio::test]
async fn test_fn_approx_percentile_cont() -> Result<()> {
let expr = approx_percentile_cont(col("b"), lit(0.5), None);
let expr = approx_percentile_cont(vec![col("b").sort(true, false)], lit(0.5), None);

let expected = [
"+---------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5)) |",
"+---------------------------------------------+",
"| 10 |",
"+---------------------------------------------+",
"+-----------------------------------------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5)) WITHIN GROUP [b ASC NULLS LAST] |",
"+-----------------------------------------------------------------------------+",
"| 10 |",
"+-----------------------------------------------------------------------------+",
];

let df = create_test_table().await?;
@@ -381,27 +381,27 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
None::<&str>,
"arg_2".to_string(),
));
let expr = approx_percentile_cont(col("b"), alias_expr, None);
let expr = approx_percentile_cont(vec![col("b").sort(true, false)], alias_expr, None);
let df = create_test_table().await?;
let expected = [
"+--------------------------------------+",
"| approx_percentile_cont(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"+--------------------------------------+",
"+----------------------------------------------------------------------+",
"| approx_percentile_cont(test.b,arg_2) WITHIN GROUP [b ASC NULLS LAST] |",
"+----------------------------------------------------------------------+",
"| 10 |",
"+----------------------------------------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

// with number of centroids set
let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2)));
let expr = approx_percentile_cont(vec![col("b").sort(true, false)], lit(0.5), Some(lit(2)));
let expected = [
"+------------------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |",
"+------------------------------------------------------+",
"| 30 |",
"+------------------------------------------------------+",
"+--------------------------------------------------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) WITHIN GROUP [b ASC NULLS LAST] |",
"+--------------------------------------------------------------------------------------+",
"| 30 |",
"+--------------------------------------------------------------------------------------+",
];

let df = create_test_table().await?;
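Every expected output in the tests above gains a ` WITHIN GROUP [b ASC NULLS LAST]` suffix because the ordering is now part of the aggregate's display name. A hypothetical sketch of that formatting rule (illustrative only, not DataFusion's actual display code):

```rust
// Append the WITHIN GROUP ordering to an aggregate's display name, matching
// the shape of the updated expected test output above.
fn display_name(base: &str, within_group: &[&str]) -> String {
    if within_group.is_empty() {
        base.to_string()
    } else {
        format!("{base} WITHIN GROUP [{}]", within_group.join(", "))
    }
}

fn main() {
    assert_eq!(
        display_name(
            "approx_percentile_cont(test.b,Float64(0.5))",
            &["b ASC NULLS LAST"]
        ),
        "approx_percentile_cont(test.b,Float64(0.5)) WITHIN GROUP [b ASC NULLS LAST]"
    );
    // Aggregates without an ordering keep their plain name.
    assert_eq!(display_name("count(a)", &[]), "count(a)");
    println!("ok");
}
```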