Skip to content

Commit 1c38aff

Browse files
authored
Fix empty aggregation function count() in Substrait (apache#15345)
* Fix empty aggregation function count() in Substrait * Fix window function function count() with no arguments in Substrait * Add explanatory comments
1 parent 4af5cfc commit 1c38aff

File tree

4 files changed

+328
-15
lines changed

4 files changed

+328
-15
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,6 +1975,15 @@ pub async fn from_substrait_agg_func(
19751975

19761976
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
19771977

1978+
// Datafusion does not support aggregate functions with no arguments, so
1979+
// we inject a dummy argument that does not affect the query, but allows
1980+
// us to bypass this limitation.
1981+
let args = if udaf.name() == "count" && args.is_empty() {
1982+
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
1983+
} else {
1984+
args
1985+
};
1986+
19781987
Ok(Arc::new(Expr::AggregateFunction(
19791988
expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None),
19801989
)))
@@ -2248,11 +2257,19 @@ pub async fn from_window_function(
22482257

22492258
window_frame.regularize_order_bys(&mut order_by)?;
22502259

2260+
// Datafusion does not support aggregate functions with no arguments, so
2261+
// we inject a dummy argument that does not affect the query, but allows
2262+
// us to bypass this limitation.
2263+
let args = if fun.name() == "count" && window.arguments.is_empty() {
2264+
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
2265+
} else {
2266+
from_substrait_func_args(consumer, &window.arguments, input_schema).await?
2267+
};
2268+
22512269
Ok(Expr::WindowFunction(expr::WindowFunction {
22522270
fun,
22532271
params: WindowFunctionParams {
2254-
args: from_substrait_func_args(consumer, &window.arguments, input_schema)
2255-
.await?,
2272+
args,
22562273
partition_by: from_substrait_rex_vec(
22572274
consumer,
22582275
&window.partitions,
@@ -3406,4 +3423,31 @@ mod test {
34063423

34073424
Ok(())
34083425
}
3426+
3427+
#[tokio::test]
3428+
async fn window_function_with_count() -> Result<()> {
3429+
let substrait = substrait::proto::Expression {
3430+
rex_type: Some(substrait::proto::expression::RexType::WindowFunction(
3431+
substrait::proto::expression::WindowFunction {
3432+
function_reference: 0,
3433+
..Default::default()
3434+
},
3435+
)),
3436+
};
3437+
3438+
let mut consumer = test_consumer();
3439+
3440+
let mut extensions = Extensions::default();
3441+
extensions.register_function("count".to_string());
3442+
consumer.extensions = &extensions;
3443+
3444+
match from_substrait_rex(&consumer, &substrait, &DFSchema::empty()).await? {
3445+
Expr::WindowFunction(window_function) => {
3446+
assert_eq!(window_function.params.args.len(), 1)
3447+
}
3448+
_ => panic!("expr was not a WindowFunction"),
3449+
};
3450+
3451+
Ok(())
3452+
}
34093453
}

datafusion/substrait/tests/cases/consumer_integration.rs

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ mod tests {
4242

4343
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
4444
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
45+
ctx.state().create_physical_plan(&plan).await?;
4546
Ok(format!("{}", plan))
4647
}
4748

@@ -50,9 +51,9 @@ mod tests {
5051
let plan_str = tpch_plan_to_string(1).await?;
5152
assert_eq!(
5253
plan_str,
53-
"Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count() AS COUNT_ORDER\
54+
"Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER\
5455
\n Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST\
55-
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count()]]\
56+
\n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\
5657
\n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\
5758
\n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\
5859
\n TableScan: LINEITEM"
@@ -119,9 +120,9 @@ mod tests {
119120
let plan_str = tpch_plan_to_string(4).await?;
120121
assert_eq!(
121122
plan_str,
122-
"Projection: ORDERS.O_ORDERPRIORITY, count() AS ORDER_COUNT\
123+
"Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT\
123124
\n Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST\
124-
\n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count()]]\
125+
\n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]]\
125126
\n Projection: ORDERS.O_ORDERPRIORITY\
126127
\n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS (<subquery>)\
127128
\n Subquery:\
@@ -269,10 +270,10 @@ mod tests {
269270
let plan_str = tpch_plan_to_string(13).await?;
270271
assert_eq!(
271272
plan_str,
272-
"Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count() AS CUSTDIST\
273-
\n Sort: count() DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\
274-
\n Projection: count(ORDERS.O_ORDERKEY), count()\
275-
\n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count()]]\
273+
"Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST\
274+
\n Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\
275+
\n Projection: count(ORDERS.O_ORDERKEY), count(Int64(1))\
276+
\n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]]\
276277
\n Projection: count(ORDERS.O_ORDERKEY)\
277278
\n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\
278279
\n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\
@@ -410,10 +411,10 @@ mod tests {
410411
let plan_str = tpch_plan_to_string(21).await?;
411412
assert_eq!(
412413
plan_str,
413-
"Projection: SUPPLIER.S_NAME, count() AS NUMWAIT\
414+
"Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT\
414415
\n Limit: skip=0, fetch=100\
415-
\n Sort: count() DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\
416-
\n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count()]]\
416+
\n Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\
417+
\n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]]\
417418
\n Projection: SUPPLIER.S_NAME\
418419
\n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS (<subquery>) AND NOT EXISTS (<subquery>) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\
419420
\n Subquery:\
@@ -438,9 +439,9 @@ mod tests {
438439
let plan_str = tpch_plan_to_string(22).await?;
439440
assert_eq!(
440441
plan_str,
441-
"Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count() AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\
442+
"Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\
442443
\n Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST\
443-
\n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(), sum(CUSTOMER.C_ACCTBAL)]]\
444+
\n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]]\
444445
\n Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL\
445446
\n Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8)) AND CUSTOMER.C_ACCTBAL > (<subquery>) AND NOT EXISTS (<subquery>)\
446447
\n Subquery:\
@@ -455,4 +456,43 @@ mod tests {
455456
);
456457
Ok(())
457458
}
459+
460+
async fn test_plan_to_string(name: &str) -> Result<String> {
461+
let path = format!("tests/testdata/test_plans/{name}");
462+
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
463+
File::open(path).expect("file not found"),
464+
))
465+
.expect("failed to parse json");
466+
467+
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
468+
let plan = from_substrait_plan(&ctx.state(), &proto).await?;
469+
ctx.state().create_physical_plan(&plan).await?;
470+
Ok(format!("{}", plan))
471+
}
472+
473+
#[tokio::test]
474+
async fn test_select_count_from_select_1() -> Result<()> {
475+
let plan_str =
476+
test_plan_to_string("select_count_from_select_1.substrait.json").await?;
477+
478+
assert_eq!(
479+
plan_str,
480+
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
481+
\n Values: (Int64(0))"
482+
);
483+
Ok(())
484+
}
485+
486+
#[tokio::test]
487+
async fn test_select_window_count() -> Result<()> {
488+
let plan_str = test_plan_to_string("select_window_count.substrait.json").await?;
489+
490+
assert_eq!(
491+
plan_str,
492+
"Projection: count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\
493+
\n WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\
494+
\n TableScan: DATA"
495+
);
496+
Ok(())
497+
}
458498
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
{
2+
"extensionUris": [
3+
{
4+
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
5+
}
6+
],
7+
"extensions": [
8+
{
9+
"extensionFunction": {
10+
"functionAnchor": 185,
11+
"name": "count:any"
12+
}
13+
}
14+
],
15+
"relations": [
16+
{
17+
"root": {
18+
"input": {
19+
"aggregate": {
20+
"common": {
21+
"direct": {
22+
}
23+
},
24+
"input": {
25+
"read": {
26+
"common": {
27+
"direct": {
28+
}
29+
},
30+
"baseSchema": {
31+
"names": [
32+
"dummy"
33+
],
34+
"struct": {
35+
"types": [
36+
{
37+
"i64": {
38+
"nullability": "NULLABILITY_REQUIRED"
39+
}
40+
}
41+
],
42+
"nullability": "NULLABILITY_REQUIRED"
43+
}
44+
},
45+
"virtualTable": {
46+
"values": [
47+
{
48+
"fields": [
49+
{
50+
"i64": "0",
51+
"nullable": false
52+
}
53+
]
54+
}
55+
]
56+
}
57+
}
58+
},
59+
"groupings": [
60+
{
61+
"groupingExpressions": [],
62+
"expressionReferences": []
63+
}
64+
],
65+
"measures": [
66+
{
67+
"measure": {
68+
"functionReference": 185,
69+
"args": [],
70+
"sorts": [],
71+
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
72+
"outputType": {
73+
"i64": {
74+
"nullability": "NULLABILITY_REQUIRED"
75+
}
76+
},
77+
"invocation": "AGGREGATION_INVOCATION_ALL",
78+
"arguments": [],
79+
"options": []
80+
}
81+
}
82+
],
83+
"groupingExpressions": []
84+
}
85+
},
86+
"names": [
87+
"count(*)"
88+
]
89+
}
90+
}
91+
]
92+
}

0 commit comments

Comments
 (0)