Skip to content

Commit 24fa033

Browse files
authored
Merge pull request #556 from mspruc/main
Reimplement AVG() for sql-api java platforms
2 parents 7784fd1 + 626d92e commit 24fa033

4 files changed

Lines changed: 588 additions & 568 deletions

File tree

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateAddCols.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626

2727
public class AggregateAddCols implements FunctionDescriptor.SerializableFunction<Record, Record> {
2828
final List<AggregateCall> aggregateCalls;
29-
30-
public AggregateAddCols(final List<AggregateCall> aggregateCalls){
29+
30+
public AggregateAddCols(final List<AggregateCall> aggregateCalls) {
3131
this.aggregateCalls = aggregateCalls;
3232
}
3333

@@ -36,7 +36,7 @@ public Record apply(final Record record) {
3636
final int l = record.size();
3737
final int newRecordSize = l + aggregateCalls.size() + 1;
3838
final Object[] resValues = new Object[newRecordSize];
39-
39+
4040
for (int i = 0; i < l; i++) {
4141
resValues[i] = record.getField(i);
4242
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateFunction.java

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.wayang.api.sql.calcite.converter.functions;
1919

20+
import java.util.Arrays;
2021
import java.util.List;
2122
import java.util.Optional;
2223
import java.util.function.BiFunction;
@@ -39,7 +40,7 @@ public AggregateFunction(final List<AggregateCall> aggregateCalls) {
3940
public Record apply(final Record record1, final Record record2) {
4041
final int l = record1.size();
4142
final Object[] resValues = new Object[l];
42-
final boolean countDone = false;
43+
boolean countDone = false;
4344

4445
for (int i = 0; i < l - aggregateCalls.size() - 1; i++) {
4546
resValues[i] = record1.getField(i);
@@ -61,25 +62,33 @@ public Record apply(final Record record1, final Record record2) {
6162
case MAX:
6263
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::greatest, SqlFunctions::greatest,
6364
SqlFunctions::greatest, SqlFunctions::greatest);
65+
break;
6466
case COUNT:
6567
// since aggregates inject an extra column for counting before,
6668
// see AggregateAddCols. the column we operate on are integer counts,
6769
// which means we can eagerly get the fields as integers and simply sum
6870
assert (field1 instanceof Integer && field2 instanceof Integer)
6971
: "Expected to find integers for count but found: " + field1 + " and " + field2;
70-
Object obj = Integer.class.cast(field1) + Integer.class.cast(field2);
71-
resValues[counter] = obj;
72+
final Object count = Integer.class.cast(field1) + Integer.class.cast(field2);
73+
resValues[counter] = count;
7274
break;
7375
case AVG:
74-
throw new UnsupportedOperationException("Averages not currently supported");
75-
// resValues[counter] = this.castAndMap(field1, field2, null, null, null, null);
76-
// break;
76+
assert (field1 instanceof Integer && field2 instanceof Integer)
77+
: "Expected to find integers for count but found: " + field1 + " and " + field2;
78+
final Object avg = Integer.class.cast(field1) + Integer.class.cast(field2);
79+
80+
resValues[counter] = avg;
81+
82+
if (!countDone) {
83+
resValues[l - 1] = record1.getInt(l - 1) + record2.getInt(l - 1);
84+
countDone = true;
85+
}
86+
break;
7787
default:
7888
throw new IllegalStateException("Unsupported operation: " + aggregateCall.getAggregation().kind);
7989
}
8090
counter++;
8191
}
82-
8392
return new Record(resValues);
8493
}
8594

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateGetResult.java

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter.functions;
2020

21+
import java.util.Arrays;
2122
import java.util.List;
2223
import java.util.Set;
24+
import java.util.stream.IntStream;
25+
import java.util.stream.Stream;
2326

2427
import org.apache.calcite.rel.core.AggregateCall;
28+
import org.apache.calcite.sql.SqlKind;
2529
import org.apache.wayang.basic.data.Record;
2630
import org.apache.wayang.core.function.FunctionDescriptor;
2731

@@ -36,31 +40,22 @@ public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<In
3640

3741
@Override
3842
public Record apply(final Record record) {
39-
final int l = record.size();
40-
final int outputRecordSize = aggregateCallList.size() + groupingfields.size();
41-
final Object[] resValues = new Object[outputRecordSize];
42-
43-
int i = 0;
44-
int j = 0;
45-
for (i = 0; j < groupingfields.size(); i++) {
46-
if (groupingfields.contains(i)) {
47-
resValues[j] = record.getField(i);
48-
j++;
49-
}
50-
}
51-
52-
i = l - aggregateCallList.size() - 1;
53-
for (final AggregateCall aggregateCall : aggregateCallList) {
54-
final String name = aggregateCall.getAggregation().getName();
55-
if (name.equals("AVG")) {
56-
resValues[j] = record.getDouble(i) / record.getDouble(l - 1);
57-
} else {
58-
resValues[j] = record.getField(i);
59-
}
60-
j++;
61-
i++;
62-
}
63-
64-
return new Record(resValues);
43+
final int recordSize = record.size();
44+
final int aggregateCallOffset = recordSize - aggregateCallList.size() - 1;
45+
46+
final Object[] fields = groupingfields.stream()
47+
.map(record::getField)
48+
.toArray();
49+
50+
final Object[] aggregateCallFields = IntStream.range(0, aggregateCallList.size())
51+
.mapToObj(i -> aggregateCallList.get(i).getAggregation().getKind().equals(SqlKind.AVG)
52+
? record.getDouble(i + aggregateCallOffset) / record.getDouble(recordSize - 1)
53+
: record.getField(i + aggregateCallOffset))
54+
.toArray();
55+
56+
final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
57+
.toArray();
58+
59+
return new Record(combinedFields);
6560
}
6661
}

0 commit comments

Comments
 (0)