Support WITHIN GROUP syntax to standardize certain existing aggregate functions #13511

Open. Wants to merge 42 commits into base: main.

Commits (42)
a9b901a - Add within group variable to aggregate function and arguments (Garamda, Nov 21, 2024)
0918000 - Merge branch 'main' into support_within_group_for_existing_aggregate_… (Garamda, Nov 21, 2024)
070a96b - Support within group and disable null handling for ordered set aggreg… (Garamda, Jan 21, 2025)
3fd92fd - Refactored function to match updated signature (Garamda, Jan 21, 2025)
4082a78 - Modify proto to support within group clause (Garamda, Jan 21, 2025)
c3be3c6 - Modify physical planner and accumulator to support ordered set aggreg… (Garamda, Jan 21, 2025)
9fd05a3 - Support session management for ordered set aggregate functions (Garamda, Jan 23, 2025)
8518a59 - Align code, tests, and examples with changes to aggregate function logic (Garamda, Jan 25, 2025)
79669d9 - Fix typo in existing comments (Garamda, Jan 25, 2025)
597f4d7 - Enhance test (Garamda, Jan 27, 2025)
d3b483c - Merge branch 'main' into support_within_group_for_existing_aggregate_… (Garamda, Jan 28, 2025)
a827c9d - Fix bug: handle missing within_group when applying children tree node (Garamda, Jan 30, 2025)
23bdf70 - Change the signature of approx_percentile_cont for consistency (Garamda, Jan 30, 2025)
97d96ca - Add missing within_group for expr display (Garamda, Jan 30, 2025)
1b61b5b - Handle edge case when over and within group clause are used together (Garamda, Jan 31, 2025)
d0fdde3 - Apply clippy advice: avoids too many arguments (Garamda, Feb 1, 2025)
3c8bce3 - Add new test cases using descending order (Garamda, Feb 1, 2025)
be99a35 - Apply cargo fmt (Garamda, Feb 1, 2025)
f9aa1fc - Revert unintended submodule changes (Garamda, Feb 5, 2025)
d5f0b62 - Apply prettier guidance (Garamda, Feb 5, 2025)
d7f2f59 - Apply doc guidance by update_function_doc.sh (Garamda, Feb 7, 2025)
7ef2139 - Merge branch 'main' into support_within_group_for_existing_aggregate_… (Garamda, Feb 7, 2025)
91565b3 - Rollback WITHIN GROUP and related logic after converting it into expr (Garamda, Feb 27, 2025)
d482bff - Rollback ordered set aggregate functions from session to save same in… (Garamda, Feb 27, 2025)
005a27c - Convert within group to order by when converting sql to expr (Garamda, Feb 28, 2025)
1179bc4 - Rollback within group from proto (Garamda, Feb 28, 2025)
e5fc1a4 - Utilize within group as order by in functions-aggregate (Garamda, Feb 28, 2025)
cf4faad - Apply clippy (Garamda, Feb 28, 2025)
fc7d2bc - Merge branch 'main' into support_within_group_for_existing_aggregate_… (Garamda, Feb 28, 2025)
5469e39 - Convert order by to within group (Garamda, Feb 28, 2025)
d96b667 - Apply cargo fmt (Garamda, Feb 28, 2025)
293d33e - Remove plain line breaks (Garamda, Feb 28, 2025)
ecdb21b - Remove duplicated column arg in schema name (Garamda, Mar 1, 2025)
d65420e - Refactor boolean functions to just return primitive type (Garamda, Mar 1, 2025)
b6d426a - Make within group necessary in the signature of existing ordered set … (Garamda, Mar 1, 2025)
4b0c52f - Apply cargo fmt (Garamda, Mar 1, 2025)
36a732d - Support a single ordering expression in the signature (Garamda, Mar 1, 2025)
8d6db85 - Apply cargo fmt (Garamda, Mar 1, 2025)
db0355a - Add dataframe function test cases to verify descending ordering (Garamda, Mar 1, 2025)
37b783e - Apply cargo fmt (Garamda, Mar 1, 2025)
124d8c5 - Apply code reviews (Garamda, Mar 5, 2025)
3259c95 - Update error msg in test as corresponding code changed (Garamda, Mar 5, 2025)
4 changes: 2 additions & 2 deletions datafusion/core/benches/aggregate_query_sql.rs
@@ -148,7 +148,7 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| {
query(
ctx.clone(),
"SELECT utf8, approx_percentile_cont(u64_wide, 0.5, 2500) \
"SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY u64_wide) \
FROM t GROUP BY utf8",
)
})
@@ -158,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| {
query(
ctx.clone(),
"SELECT utf8, approx_percentile_cont(f32, 0.5, 2500) \
"SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY f32) \
FROM t GROUP BY utf8",
)
})
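For context, the benchmark queries above now use the standardized call form: the percentile and centroid count stay as regular arguments, while the ordering column moves into the WITHIN GROUP clause. Below is a minimal end-to-end sketch (not part of this PR) of running the same kind of query through `SessionContext::sql`, assuming a hypothetical `t.csv` with `utf8` and `u64_wide` columns:

```rust
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Hypothetical data file; any table with `utf8` and numeric `u64_wide` columns works.
    ctx.register_csv("t", "t.csv", CsvReadOptions::new()).await?;

    // Percentile and centroid count first; the ordering column lives in WITHIN GROUP.
    let df = ctx
        .sql(
            "SELECT utf8, approx_percentile_cont(0.5, 2500) WITHIN GROUP (ORDER BY u64_wide) \
             FROM t GROUP BY utf8",
        )
        .await?;
    df.show().await?;
    Ok(())
}
```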
84 changes: 66 additions & 18 deletions datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -360,14 +360,29 @@ async fn test_fn_approx_median() -> Result<()> {

#[tokio::test]
async fn test_fn_approx_percentile_cont() -> Result<()> {
let expr = approx_percentile_cont(col("b"), lit(0.5), None);
let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), None);

let expected = [
"+---------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5)) |",
"+---------------------------------------------+",
"| 10 |",
"+---------------------------------------------+",
"+---------------------------------------------------------------------------+",
"| approx_percentile_cont(Float64(0.5)) WITHIN GROUP [test.b ASC NULLS LAST] |",
"+---------------------------------------------------------------------------+",
"| 10 |",
"+---------------------------------------------------------------------------+",
];

let df = create_test_table().await?;
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

let expr = approx_percentile_cont(col("b").sort(false, false), lit(0.1), None);

let expected = [
"+----------------------------------------------------------------------------+",
"| approx_percentile_cont(Float64(0.1)) WITHIN GROUP [test.b DESC NULLS LAST] |",
"+----------------------------------------------------------------------------+",
"| 100 |",
"+----------------------------------------------------------------------------+",
];

let df = create_test_table().await?;
@@ -381,27 +396,60 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
None::<&str>,
"arg_2".to_string(),
));
let expr = approx_percentile_cont(col("b"), alias_expr, None);
let expr = approx_percentile_cont(col("b").sort(true, false), alias_expr, None);
let df = create_test_table().await?;
let expected = [
"+--------------------------------------------------------------------+",
"| approx_percentile_cont(arg_2) WITHIN GROUP [test.b ASC NULLS LAST] |",
"+--------------------------------------------------------------------+",
"| 10 |",
"+--------------------------------------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

let alias_expr = Expr::Alias(Alias::new(
cast(lit(0.1), DataType::Float32),
None::<&str>,
"arg_2".to_string(),
));
let expr = approx_percentile_cont(col("b").sort(false, false), alias_expr, None);
let df = create_test_table().await?;
let expected = [
"+--------------------------------------+",
"| approx_percentile_cont(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"+--------------------------------------+",
"+---------------------------------------------------------------------+",
"| approx_percentile_cont(arg_2) WITHIN GROUP [test.b DESC NULLS LAST] |",
"+---------------------------------------------------------------------+",
"| 100 |",
"+---------------------------------------------------------------------+",
];
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

// with number of centroids set
let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2)));
let expr = approx_percentile_cont(col("b").sort(true, false), lit(0.5), Some(lit(2)));
let expected = [
"+------------------------------------------------------------------------------------+",
"| approx_percentile_cont(Float64(0.5),Int32(2)) WITHIN GROUP [test.b ASC NULLS LAST] |",
"+------------------------------------------------------------------------------------+",
"| 30 |",
"+------------------------------------------------------------------------------------+",
];

let df = create_test_table().await?;
let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;

assert_batches_eq!(expected, &batches);

let expr =
approx_percentile_cont(col("b").sort(false, false), lit(0.1), Some(lit(2)));
let expected = [
"+------------------------------------------------------+",
"| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |",
"+------------------------------------------------------+",
"| 30 |",
"+------------------------------------------------------+",
"+-------------------------------------------------------------------------------------+",
"| approx_percentile_cont(Float64(0.1),Int32(2)) WITHIN GROUP [test.b DESC NULLS LAST] |",
"+-------------------------------------------------------------------------------------+",
"| 69 |",
"+-------------------------------------------------------------------------------------+",
];

let df = create_test_table().await?;
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr.rs
@@ -295,6 +295,8 @@ pub enum Expr {
/// See also [`ExprFunctionExt`] to set these fields.
///
/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt
///
/// Note: `WITHIN GROUP` is converted to `ORDER BY` internally in `datafusion/sql/src/expr/function.rs`.
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
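The note added above captures the final design: WITHIN GROUP never becomes a separate field on the logical plan; the SQL planner folds it into the aggregate's `order_by`. As a rough sketch (not taken from this PR's tests), the expression produced for `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY b)` could be built by hand like this, mirroring the `new_udf` call used inside `approx_percentile_cont()` later in this diff:

```rust
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::{col, lit, Expr};
use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;

/// Roughly what `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY b)` becomes
/// after SQL planning: the ordering column remains the first argument and the
/// WITHIN GROUP ordering lands in the aggregate's `order_by` slot.
fn within_group_as_order_by() -> Expr {
    Expr::AggregateFunction(AggregateFunction::new_udf(
        approx_percentile_cont_udaf(),
        vec![col("b"), lit(0.5)],
        false,                                  // DISTINCT
        None,                                   // FILTER
        Some(vec![col("b").sort(true, false)]), // ORDER BY b ASC NULLS LAST
        None,                                   // null treatment
    ))
}
```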
38 changes: 37 additions & 1 deletion datafusion/expr/src/udaf.rs
@@ -308,6 +308,16 @@ impl AggregateUDF {
self.inner.default_value(data_type)
}

/// See [`AggregateUDFImpl::supports_null_handling_clause`] for more details.
pub fn supports_null_handling_clause(&self) -> bool {
self.inner.supports_null_handling_clause()
}

/// See [`AggregateUDFImpl::is_ordered_set_aggregate`] for more details.
pub fn is_ordered_set_aggregate(&self) -> bool {
self.inner.is_ordered_set_aggregate()
}

/// Returns the documentation for this Aggregate UDF.
///
/// Documentation can be accessed programmatically as well as
@@ -425,6 +435,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
null_treatment,
} = params;

// Exclude the first function argument (the ordering column) for ordered-set aggregate
// functions, because it would be duplicated by the WITHIN GROUP clause in the schema name.
let args = if self.is_ordered_set_aggregate() {
&args[1..]
} else {
&args[..]
};

let mut schema_name = String::new();

schema_name.write_fmt(format_args!(
@@ -443,8 +461,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
};

if let Some(order_by) = order_by {
let clause = match self.is_ordered_set_aggregate() {
true => "WITHIN GROUP",
false => "ORDER BY",
};

schema_name.write_fmt(format_args!(
" ORDER BY [{}]",
" {} [{}]",
clause,
schema_name_from_sorts(order_by)?
))?;
};
@@ -845,6 +869,18 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
ScalarValue::try_from(data_type)
}

/// Returns true if this function supports the `[IGNORE NULLS | RESPECT NULLS]` clause,
/// and false if it does not.
fn supports_null_handling_clause(&self) -> bool {
true
}

Contributor:
Is this something we need? From what I know, there aren't any aggregate functions that have options for null handling. At the moment, the 2 overrides you have of this both return Some(false), which is what I would consider the default value anyway.

Speaking of which, if we do need this, do we need to return an Optional<bool> or could we just return bool directly?

Author (@Garamda), Mar 1, 2025:
There are some aggregate functions using null handling in current datafusion.
(cf. If this is something we need to discuss/fix, then I can make another git issue. Or, I can refactor it too in this PR. I left this comment because I am not 100% sure about the SQL standard.)

And I refactored the function to just return bool.

Contributor:
This was smelling odd, so I dug a bit deeper. I think you've inadvertently stumbled into something even weirder than you anticipated.

The example you've linked is

SELECT FIRST_VALUE(column1) RESPECT NULLS FROM t;

which I don't think is a valid query, because first_value should not be an aggregate function, or at the very least the above query is not valid in most SQL dialects. first_value is actually a window function in other engines (e.g. Trino, Postgres, MySQL).

If you try running something like

SELECT first_value(column1) FROM t;

against Postgres you get an error like

Query Error: window function first_value requires an OVER clause

dbfiddle

The RESPECT NULLS | IGNORE NULLS option is only a property of certain window functions, hence we shouldn't need to track it for aggregate functions.

I'm going to file a ticket for the above.

Contributor:
Filed #15006


/// Returns true if this function is an ordered-set aggregate function,
/// and false if it is not.
fn is_ordered_set_aggregate(&self) -> bool {
false
}

/// Returns the documentation for this Aggregate UDF.
///
/// Documentation can be accessed programmatically as well as
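Both new trait hooks default to the pre-existing behavior, so third-party UDAFs are unaffected unless they opt in; the built-in ordered-set aggregates override both. A small sketch of how the accessors surface on `AggregateUDF`, assuming the `datafusion-functions-aggregate` crate is on the dependency list:

```rust
use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;

fn main() {
    let udaf = approx_percentile_cont_udaf();

    // Ordered-set aggregate: the schema name drops the ordering column from the
    // argument list and renders the ordering as `WITHIN GROUP [...]` rather than
    // `ORDER BY [...]`.
    assert!(udaf.is_ordered_set_aggregate());

    // Ordered-set aggregates reject the IGNORE NULLS / RESPECT NULLS clause.
    assert!(!udaf.supports_null_handling_clause());
}
```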
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
@@ -45,7 +45,7 @@ make_udaf_expr_and_func!(
/// APPROX_MEDIAN aggregate expression
#[user_doc(
doc_section(label = "Approximate Functions"),
description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.",
description = "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(0.5) WITHIN GROUP (ORDER BY x)`.",
syntax_example = "approx_median(expression)",
sql_example = r#"```sql
> SELECT approx_median(column_name) FROM table_name;
54 changes: 43 additions & 11 deletions datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -34,6 +34,7 @@ use datafusion_common::{
downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err,
Result, ScalarValue,
};
use datafusion_expr::expr::{AggregateFunction, Sort};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
use datafusion_expr::utils::format_state_name;
@@ -51,29 +52,39 @@ create_func!(ApproxPercentileCont, approx_percentile_cont_udaf);

/// Computes the approximate percentile continuous of a set of numbers
pub fn approx_percentile_cont(
expression: Expr,
within_group: Sort,
percentile: Expr,
centroids: Option<Expr>,
) -> Expr {
let expr = within_group.expr.clone();

let args = if let Some(centroids) = centroids {
vec![expression, percentile, centroids]
vec![expr, percentile, centroids]
} else {
vec![expression, percentile]
vec![expr, percentile]
};
approx_percentile_cont_udaf().call(args)

Expr::AggregateFunction(AggregateFunction::new_udf(
approx_percentile_cont_udaf(),
args,
false,
None,
Some(vec![within_group]),
None,
))
}

#[user_doc(
doc_section(label = "Approximate Functions"),
description = "Returns the approximate percentile of input values using the t-digest algorithm.",
syntax_example = "approx_percentile_cont(expression, percentile, centroids)",
syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)",
sql_example = r#"```sql
> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name;
+-------------------------------------------------+
| approx_percentile_cont(column_name, 0.75, 100) |
+-------------------------------------------------+
| 65.0 |
+-------------------------------------------------+
> SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+-----------------------------------------------------------------------+
| approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
+-----------------------------------------------------------------------+
| 65.0 |
+-----------------------------------------------------------------------+
```"#,
standard_argument(name = "expression",),
argument(
@@ -130,6 +141,19 @@ impl ApproxPercentileCont {
args: AccumulatorArgs,
) -> Result<ApproxPercentileAccumulator> {
let percentile = validate_input_percentile_expr(&args.exprs[1])?;

let is_descending = args
.ordering_req
.first()
.map(|sort_expr| sort_expr.options.descending)
.unwrap_or(false);

let percentile = if is_descending {
1.0 - percentile
} else {
percentile
};
Comment on lines +151 to +155
Author (@Garamda):
I used floating point subtraction instead of actual sorting in reverse order, for conciseness.

If any slight floating point difference is not permitted (even if this branch passed the tests), please let me know.

Contributor:
This seems reasonable to me, but I don't have that much experience on the execution side of things.


let tdigest_max_size = if args.exprs.len() == 3 {
Some(validate_input_max_size_expr(&args.exprs[2])?)
} else {
@@ -292,6 +316,14 @@ impl AggregateUDFImpl for ApproxPercentileCont {
Ok(arg_types[0].clone())
}

fn supports_null_handling_clause(&self) -> bool {
false
}

fn is_ordered_set_aggregate(&self) -> bool {
true
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
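The descending-order handling in the accumulator relies on a symmetry of the percentile definition rather than re-sorting the input: the q-th percentile under a descending ordering equals the (1 - q)-th percentile under the ascending ordering. A self-contained sketch of that identity using a plain linear-interpolation percentile (the `percentile` helper below is illustrative only, not DataFusion's t-digest):

```rust
/// Linear-interpolation percentile over a slice taken in its given order
/// (illustration only; DataFusion's accumulator uses an approximate t-digest).
fn percentile(ordered: &[f64], q: f64) -> f64 {
    let pos = q * (ordered.len() - 1) as f64;
    let (lo, hi) = (pos.floor() as usize, pos.ceil() as usize);
    ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo as f64)
}

fn main() {
    let asc = vec![10.0, 20.0, 30.0, 40.0, 100.0];
    let desc: Vec<f64> = asc.iter().rev().copied().collect();

    let q = 0.1;
    // The 0.1 quantile of the descending ordering equals the 0.9 quantile of the
    // ascending ordering, which is why the accumulator keeps a single ascending
    // digest and simply requests `1.0 - percentile` for descending orderings.
    assert!((percentile(&desc, q) - percentile(&asc, 1.0 - q)).abs() < 1e-9);
    println!("both sides evaluate to {}", percentile(&desc, q));
}
```

This is the design choice discussed in the thread above: a single ascending t-digest plus a flipped percentile, at the cost of at most a tiny floating-point difference, as the author notes.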
datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
@@ -52,14 +52,14 @@ make_udaf_expr_and_func!(
#[user_doc(
doc_section(label = "Approximate Functions"),
description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.",
syntax_example = "approx_percentile_cont_with_weight(expression, weight, percentile)",
syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)",
sql_example = r#"```sql
> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name;
+----------------------------------------------------------------------+
| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) |
+----------------------------------------------------------------------+
| 78.5 |
+----------------------------------------------------------------------+
> SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name;
+---------------------------------------------------------------------------------------------+
| approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) |
+---------------------------------------------------------------------------------------------+
| 78.5 |
+---------------------------------------------------------------------------------------------+
```"#,
standard_argument(name = "expression", prefix = "The"),
argument(
@@ -178,6 +186,14 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
self.approx_percentile_cont.state_fields(args)
}

fn supports_null_handling_clause(&self) -> bool {
false
}

fn is_ordered_set_aggregate(&self) -> bool {
true
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
4 changes: 2 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -970,8 +970,8 @@ async fn roundtrip_expr_api() -> Result<()> {
stddev_pop(lit(2.2)),
approx_distinct(lit(2)),
approx_median(lit(2)),
approx_percentile_cont(lit(2), lit(0.5), None),
approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))),
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None),
approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))),
approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)),
grouping(lit(1)),
bit_and(lit(2)),
2 changes: 1 addition & 1 deletion datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -503,7 +503,7 @@ fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> {
vec![col("b", &schema)?, lit(0.5)],
)
.schema(Arc::clone(&schema))
.alias("APPROX_PERCENTILE_CONT(b, 0.5)")
.alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)")
.build()
.map(Arc::new)?];
