Skip to content

Commit c708456

Browse files
Aaaaaaron authored and julianhyde committed
[CALCITE-4345] AggregateCaseToFilterRule throws NullPointerException when converting CASE without ELSE (Jiatao Tao)
For example, 'SUM(CASE WHEN b THEN 1 END)' is equivalent to 'SUM(CASE WHEN b THEN 1 ELSE NULL END)', and both should be converted to 'SUM(1) FILTER (WHERE b)', but before this bug was fixed the former would throw NullPointerException. Close #2225
1 parent add837a commit c708456

File tree

4 files changed

+71
-11
lines changed

4 files changed

+71
-11
lines changed

core/src/main/java/org/apache/calcite/rel/rules/AggregateCaseToFilterRule.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import com.google.common.collect.ImmutableList;
4141

42+
import java.math.BigDecimal;
4243
import java.util.ArrayList;
4344
import java.util.List;
4445
import javax.annotation.Nullable;
@@ -227,8 +228,8 @@ && isThreeArgCase(project.getProjects().get(singleArg))) {
227228
RelCollations.EMPTY, aggregateCall.getType(),
228229
aggregateCall.getName());
229230
} else if (kind == SqlKind.SUM // Case B
230-
&& isIntLiteral(arg1) && RexLiteral.intValue(arg1) == 1
231-
&& isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
231+
&& isIntLiteral(arg1, BigDecimal.ONE)
232+
&& isIntLiteral(arg2, BigDecimal.ZERO)) {
232233

233234
newProjects.add(filter);
234235
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
@@ -241,8 +242,7 @@ && isIntLiteral(arg2) && RexLiteral.intValue(arg2) == 0) {
241242
} else if ((RexLiteral.isNullLiteral(arg2) // Case A1
242243
&& aggregateCall.getAggregation().allowsFilter())
243244
|| (kind == SqlKind.SUM // Case A2
244-
&& isIntLiteral(arg2)
245-
&& RexLiteral.intValue(arg2) == 0)) {
245+
&& isIntLiteral(arg2, BigDecimal.ZERO))) {
246246
newProjects.add(arg1);
247247
newProjects.add(filter);
248248
return AggregateCall.create(aggregateCall.getAggregation(), false,
@@ -267,9 +267,10 @@ private static boolean isThreeArgCase(final RexNode rexNode) {
267267
&& ((RexCall) rexNode).operands.size() == 3;
268268
}
269269

270-
private static boolean isIntLiteral(final RexNode rexNode) {
270+
private static boolean isIntLiteral(RexNode rexNode, BigDecimal value) {
271271
return rexNode instanceof RexLiteral
272-
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName());
272+
&& SqlTypeName.INT_TYPES.contains(rexNode.getType().getSqlTypeName())
273+
&& value.equals(((RexLiteral) rexNode).getValueAs(BigDecimal.class));
273274
}
274275

275276
/** Rule configuration. */

core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java

+4
Original file line numberDiff line numberDiff line change
@@ -3826,6 +3826,10 @@ public boolean test(Project project) {
38263826
+ " sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,\n"
38273827
+ " sum(case when deptno = 30 then 1 else 0 end) as count_d30,\n"
38283828
+ " count(case when deptno = 40 then 'x' end) as count_d40,\n"
3829+
+ " sum(case when deptno = 45 then 1 end) as count_d45,\n"
3830+
+ " sum(case when deptno = 50 then 1 else null end) as count_d50,\n"
3831+
+ " sum(case when deptno = 60 then null end) as sum_null_d60,\n"
3832+
+ " sum(case when deptno = 70 then null else 1 end) as sum_null_d70,\n"
38293833
+ " count(case when deptno = 20 then 1 end) as count_d20\n"
38303834
+ "from emp";
38313835
sql(sql).withRule(CoreRules.AGGREGATE_CASE_TO_FILTER).check();

core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml

+9-5
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,25 @@
2727
sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,
2828
sum(case when deptno = 30 then 1 else 0 end) as count_d30,
2929
count(case when deptno = 40 then 'x' end) as count_d40,
30+
sum(case when deptno = 45 then 1 end) as count_d45,
31+
sum(case when deptno = 50 then 1 else null end) as count_d50,
32+
sum(case when deptno = 60 then null end) as sum_null_d60,
33+
sum(case when deptno = 70 then null else 1 end) as sum_null_d70,
3034
count(case when deptno = 20 then 1 end) as count_d20
3135
from emp]]>
3236
</Resource>
3337
<Resource name="planBefore">
3438
<![CDATA[
35-
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1)], SUM_SAL_D10=[SUM($2)], SUM_SAL_D20=[SUM($3)], COUNT_D30=[SUM($4)], COUNT_D40=[COUNT($5)], COUNT_D20=[COUNT($6)])
36-
LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 20), 1, null:INTEGER)])
39+
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1)], SUM_SAL_D10=[SUM($2)], SUM_SAL_D20=[SUM($3)], COUNT_D30=[SUM($4)], COUNT_D40=[COUNT($5)], COUNT_D45=[SUM($6)], COUNT_D50=[SUM($7)], SUM_NULL_D60=[SUM($8)], SUM_NULL_D70=[SUM($9)], COUNT_D20=[COUNT($10)])
40+
LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 45), 1, null:INTEGER)], $f7=[CASE(=($7, 50), 1, null:INTEGER)], $f8=[null:DECIMAL(19, 9)], $f9=[CASE(=($7, 70), null:INTEGER, 1)], $f10=[CASE(=($7, 20), 1, null:INTEGER)])
3741
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
3842
]]>
3943
</Resource>
4044
<Resource name="planAfter">
4145
<![CDATA[
42-
LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D20=[$6])
43-
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1) FILTER $2], SUM_SAL_D10=[SUM($3) FILTER $4], SUM_SAL_D20=[SUM($5) FILTER $6], COUNT_D30=[COUNT() FILTER $7], COUNT_D40=[COUNT() FILTER $8], COUNT_D20=[COUNT() FILTER $9])
44-
LogicalProject(SAL=[$5], DEPTNO=[$7], $f8=[=($2, 'CLERK')], SAL0=[$5], $f10=[=($7, 10)], SAL1=[$5], $f12=[=($7, 20)], $f13=[=($7, 30)], $f14=[=($7, 40)], $f15=[=($7, 20)])
46+
LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D45=[$6], COUNT_D50=[$7], SUM_NULL_D60=[$8], SUM_NULL_D70=[$9], COUNT_D20=[$10])
47+
LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $2) FILTER $3], SUM_SAL_D10=[SUM($4) FILTER $5], SUM_SAL_D20=[SUM($6) FILTER $7], COUNT_D30=[COUNT() FILTER $8], COUNT_D40=[COUNT() FILTER $9], COUNT_D45=[SUM($10) FILTER $11], COUNT_D50=[SUM($12) FILTER $13], SUM_NULL_D60=[SUM($1)], SUM_NULL_D70=[SUM($14) FILTER $15], COUNT_D20=[COUNT() FILTER $16])
48+
LogicalProject(SAL=[$5], $f8=[null:DECIMAL(19, 9)], DEPTNO=[$7], $f12=[=($2, 'CLERK')], SAL0=[$5], $f14=[=($7, 10)], SAL1=[$5], $f16=[=($7, 20)], $f17=[=($7, 30)], $f18=[=($7, 40)], $f19=[1], $f20=[=($7, 45)], $f21=[1], $f22=[=($7, 50)], $f23=[1], $f24=[<>($7, 70)], $f25=[=($7, 20)])
4549
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
4650
]]>
4751
</Resource>

core/src/test/resources/sql/agg.iq

+51
Original file line numberDiff line numberDiff line change
@@ -2245,6 +2245,57 @@ EnumerableCalc(expr#0=[{inputs}], expr#1=[0:BIGINT], expr#2=[=($t0, $t1)], expr#
22452245

22462246
!use scott
22472247

2248+
# [CALCITE-4345] SUM(CASE WHEN b THEN 1 END) etc.
2249+
select
2250+
sum(sal) as sum_sal,
2251+
count(distinct case
2252+
when job = 'CLERK'
2253+
then deptno else null end) as count_distinct_clerk,
2254+
sum(case when deptno = 10 then sal end) as sum_sal_d10,
2255+
sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,
2256+
sum(case when deptno = 30 then 1 else 0 end) as count_d30,
2257+
count(case when deptno = 40 then 'x' end) as count_d40,
2258+
sum(case when deptno = 45 then 1 end) as count_d45,
2259+
sum(case when deptno = 50 then 1 else null end) as count_d50,
2260+
sum(case when deptno = 60 then null end) as sum_null_d60,
2261+
sum(case when deptno = 70 then null else 1 end) as sum_null_d70,
2262+
count(case when deptno = 20 then 1 end) as count_d20
2263+
from emp;
2264+
+----------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
2265+
| SUM_SAL | COUNT_DISTINCT_CLERK | SUM_SAL_D10 | SUM_SAL_D20 | COUNT_D30 | COUNT_D40 | COUNT_D45 | COUNT_D50 | SUM_NULL_D60 | SUM_NULL_D70 | COUNT_D20 |
2266+
+----------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
2267+
| 29025.00 | 3 | 8750.00 | 10875.00 | 6 | 0 | | | | 14 | 5 |
2268+
+----------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
2269+
(1 row)
2270+
2271+
!ok
2272+
2273+
# Check that SUM produces NULL on empty set, COUNT produces 0.
2274+
select
2275+
sum(sal) as sum_sal,
2276+
count(distinct case
2277+
when job = 'CLERK'
2278+
then deptno else null end) as count_distinct_clerk,
2279+
sum(case when deptno = 10 then sal end) as sum_sal_d10,
2280+
sum(case when deptno = 20 then sal else 0 end) as sum_sal_d20,
2281+
sum(case when deptno = 30 then 1 else 0 end) as count_d30,
2282+
count(case when deptno = 40 then 'x' end) as count_d40,
2283+
sum(case when deptno = 45 then 1 end) as count_d45,
2284+
sum(case when deptno = 50 then 1 else null end) as count_d50,
2285+
sum(case when deptno = 60 then null end) as sum_null_d60,
2286+
sum(case when deptno = 70 then null else 1 end) as sum_null_d70,
2287+
count(case when deptno = 20 then 1 end) as count_d20
2288+
from emp
2289+
where false;
2290+
+---------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
2291+
| SUM_SAL | COUNT_DISTINCT_CLERK | SUM_SAL_D10 | SUM_SAL_D20 | COUNT_D30 | COUNT_D40 | COUNT_D45 | COUNT_D50 | SUM_NULL_D60 | SUM_NULL_D70 | COUNT_D20 |
2292+
+---------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
2293+
| | 0 | | | | 0 | | | | | 0 |
2294+
+---------+----------------------+-------------+-------------+-----------+-----------+-----------+-----------+--------------+--------------+-----------+
2295+
(1 row)
2296+
2297+
!ok
2298+
22482299
# [CALCITE-1930] AggregateExpandDistinctAggregateRules should handle multiple aggregate calls with same input ref
22492300
select count(distinct EMPNO), COUNT(SAL), MIN(SAL), MAX(SAL) from "scott".emp;
22502301
+--------+--------+--------+---------+

0 commit comments

Comments
 (0)