Skip to content

Commit 5207d44

Browse files
committed
serializable aggregations for sql-api
1 parent a3958f8 commit 5207d44

6 files changed

Lines changed: 79 additions & 47 deletions

File tree

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter;
2020

21+
import java.util.HashSet;
2122
import java.util.List;
2223
import java.util.Set;
2324

@@ -51,7 +52,7 @@ Operator visit(final WayangAggregate wayangRelNode) {
5152

5253
final List<AggregateCall> aggregateCalls = ((Aggregate) wayangRelNode).getAggCallList();
5354
final int groupCount = wayangRelNode.getGroupCount();
54-
final Set<Integer> groupingFields = wayangRelNode.getGroupSet().asSet();
55+
final HashSet<Integer> groupingFields = new HashSet<>(wayangRelNode.getGroupSet().asSet());
5556

5657
final MapOperator<Record, Record> mapOperator = new MapOperator<>(
5758
new AggregateAddCols(aggregateCalls),

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateAddCols.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,19 @@
1818
package org.apache.wayang.api.sql.calcite.converter.functions;
1919

2020
import java.util.List;
21+
import java.util.stream.Collectors;
2122

2223
import org.apache.calcite.rel.core.AggregateCall;
23-
24+
import org.apache.calcite.sql.SqlKind;
2425
import org.apache.wayang.basic.data.Record;
26+
import org.apache.wayang.basic.data.Tuple2;
2527
import org.apache.wayang.core.function.FunctionDescriptor;
2628

2729
public class AggregateAddCols implements FunctionDescriptor.SerializableFunction<Record, Record> {
28-
final List<AggregateCall> aggregateCalls;
30+
final List<Tuple2<SqlKind, List<Integer>>> aggregateCalls;
2931

3032
public AggregateAddCols(final List<AggregateCall> aggregateCalls) {
31-
this.aggregateCalls = aggregateCalls;
33+
this.aggregateCalls = aggregateCalls.stream().map(call -> new Tuple2<>(call.getAggregation().getKind(), call.getArgList())).collect(Collectors.toList());
3234
}
3335

3436
@Override
@@ -42,13 +44,16 @@ public Record apply(final Record record) {
4244
}
4345

4446
int i = l;
45-
for (final AggregateCall aggregateCall : aggregateCalls) {
46-
switch (aggregateCall.getAggregation().kind) {
47+
for (final Tuple2<SqlKind, List<Integer>> aggregateCall : aggregateCalls) {
48+
final SqlKind kind = aggregateCall.field0;
49+
final List<Integer> argList = aggregateCall.field1;
50+
51+
switch (kind) {
4752
case COUNT:
4853
resValues[i] = 1;
4954
break;
5055
default:
51-
resValues[i] = record.getField(aggregateCall.getArgList().get(0));
56+
resValues[i] = record.getField(argList.get(0));
5257
}
5358
i++;
5459
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateFunction.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,22 @@
2121
import java.util.List;
2222
import java.util.Optional;
2323
import java.util.function.BiFunction;
24+
import java.util.stream.Collectors;
2425

2526
import org.apache.calcite.rel.core.AggregateCall;
2627
import org.apache.calcite.runtime.SqlFunctions;
28+
import org.apache.calcite.sql.SqlKind;
2729
import org.apache.wayang.basic.data.Record;
2830
import org.apache.wayang.core.function.FunctionDescriptor;
2931

3032
public class AggregateFunction
3133
implements FunctionDescriptor.SerializableBinaryOperator<Record> {
32-
33-
final List<AggregateCall> aggregateCalls;
34+
final List<SqlKind> aggregateKinds;
3435

3536
public AggregateFunction(final List<AggregateCall> aggregateCalls) {
36-
this.aggregateCalls = aggregateCalls;
37+
this.aggregateKinds = aggregateCalls.stream()
38+
.map(call -> call.getAggregation().getKind())
39+
.collect(Collectors.toList());
3740
}
3841

3942
@Override
@@ -42,16 +45,16 @@ public Record apply(final Record record1, final Record record2) {
4245
final Object[] resValues = new Object[l];
4346
boolean countDone = false;
4447

45-
for (int i = 0; i < l - aggregateCalls.size() - 1; i++) {
48+
for (int i = 0; i < l - aggregateKinds.size() - 1; i++) {
4649
resValues[i] = record1.getField(i);
4750
}
4851

49-
int counter = l - aggregateCalls.size() - 1;
50-
for (final AggregateCall aggregateCall : aggregateCalls) {
52+
int counter = l - aggregateKinds.size() - 1;
53+
for (final SqlKind kind : aggregateKinds) {
5154
final Object field1 = record1.getField(counter);
5255
final Object field2 = record2.getField(counter);
5356

54-
switch (aggregateCall.getAggregation().kind) {
57+
switch (kind) {
5558
case SUM:
5659
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum);
5760
break;
@@ -85,7 +88,7 @@ public Record apply(final Record record1, final Record record2) {
8588
}
8689
break;
8790
default:
88-
throw new IllegalStateException("Unsupported operation: " + aggregateCall.getAggregation().kind);
91+
throw new IllegalStateException("Unsupported operation: " + kind);
8992
}
9093
counter++;
9194
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateGetResult.java

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.Arrays;
2222
import java.util.List;
2323
import java.util.Set;
24+
import java.util.stream.Collectors;
2425
import java.util.stream.IntStream;
2526
import java.util.stream.Stream;
2627

@@ -30,32 +31,35 @@
3031
import org.apache.wayang.core.function.FunctionDescriptor;
3132

3233
public class AggregateGetResult implements FunctionDescriptor.SerializableFunction<Record, Record> {
33-
private final List<AggregateCall> aggregateCallList;
34-
private final Set<Integer> groupingfields;
35-
36-
public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<Integer> groupingfields) {
37-
this.aggregateCallList = aggregateCalls;
38-
this.groupingfields = groupingfields;
39-
}
40-
41-
@Override
42-
public Record apply(final Record record) {
43-
final int recordSize = record.size();
44-
final int aggregateCallOffset = recordSize - aggregateCallList.size() - 1;
45-
46-
final Object[] fields = groupingfields.stream()
47-
.map(record::getField)
48-
.toArray();
49-
50-
final Object[] aggregateCallFields = IntStream.range(0, aggregateCallList.size())
51-
.mapToObj(i -> aggregateCallList.get(i).getAggregation().getKind().equals(SqlKind.AVG)
52-
? record.getDouble(i + aggregateCallOffset) / record.getDouble(recordSize - 1)
53-
: record.getField(i + aggregateCallOffset))
54-
.toArray();
55-
56-
final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
57-
.toArray();
58-
59-
return new Record(combinedFields);
60-
}
34+
private final List<SqlKind> aggregateKindList;
35+
private final Set<Integer> groupingfields;
36+
37+
public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<Integer> groupingfields) {
38+
this.aggregateKindList = aggregateCalls.stream()
39+
.map(call -> call.getAggregation().getKind())
40+
.collect(Collectors.toList());
41+
this.groupingfields = groupingfields;
42+
}
43+
44+
@Override
45+
public Record apply(final Record record) {
46+
final int recordSize = record.size();
47+
final int aggregateCallOffset = recordSize - aggregateKindList.size() - 1;
48+
49+
final Object[] fields = groupingfields.stream()
50+
.map(record::getField)
51+
.toArray();
52+
53+
final Object[] aggregateCallFields = IntStream.range(0, aggregateKindList.size())
54+
.mapToObj(i -> aggregateKindList.get(i).equals(SqlKind.AVG)
55+
? record.getDouble(i + aggregateCallOffset)
56+
/ record.getDouble(recordSize - 1)
57+
: record.getField(i + aggregateCallOffset))
58+
.toArray();
59+
60+
final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
61+
.toArray();
62+
63+
return new Record(combinedFields);
64+
}
6165
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateKeyExtractor.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@
1919
package org.apache.wayang.api.sql.calcite.converter.functions;
2020

2121
import java.util.ArrayList;
22+
import java.util.HashSet;
2223
import java.util.List;
23-
import java.util.Set;
2424

2525
import org.apache.wayang.basic.data.Record;
2626
import org.apache.wayang.core.function.FunctionDescriptor;
2727

2828
public class AggregateKeyExtractor implements FunctionDescriptor.SerializableFunction<Record, Object> {
29-
private final Set<Integer> indexSet;
29+
private final HashSet<Integer> indexSet;
3030

31-
public AggregateKeyExtractor(final Set<Integer> indexSet) {
31+
public AggregateKeyExtractor(final HashSet<Integer> indexSet) {
3232
this.indexSet = indexSet;
3333
}
3434

wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ public void javaLimit() throws Exception {
358358

359359
final List<Record> result = r.stream().collect(Collectors.toList());
360360

361-
System.out.println("limit srot: " + result);
362361
assert (result.size() == 1);
363362
assert (result.get(0).equals(new Record(2, "a", "a", 2)));
364363
}
@@ -479,6 +478,26 @@ public void sparkFilter() throws Exception {
479478
assert (result.stream().anyMatch(rec -> rec.equals(new Record("test1", "test1"))));
480479
}
481480

481+
@Test
482+
public void sparkAggregate() throws Exception {
483+
final SqlContext sqlContext = this.createSqlContext("/data/largeLeftTableIndex.csv");
484+
final Tuple2<Collection<Record>, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext,
485+
"SELECT largeLeftTableIndex.NAMEC, COUNT(*) FROM fs.largeLeftTableIndex GROUP BY NAMEC");
486+
final Collection<Record> result = t.field0;
487+
final WayangPlan wayangPlan = t.field1;
488+
489+
// run every traversed operator on Spark (original note read "except reduce by"; no operator is excluded below)
490+
PlanTraversal.upstream().traverse(wayangPlan.getSinks()).getTraversedNodes().forEach(node -> {
491+
node.addTargetPlatform(Spark.platform());
492+
});
493+
494+
sqlContext.execute(wayangPlan);
495+
496+
final Record rec = result.stream().findFirst().get();
497+
assert (rec.size() == 2);
498+
assert (rec.getInt(1) == 3);
499+
}
500+
482501
// tests sql-apis ability to serialize projections and joins
483502
@Test
484503
public void sparkInnerJoin() throws Exception {

0 commit comments

Comments
 (0)