Skip to content

Commit a4dab0e

Browse files
Merge branch 'main' into migrate_to_junit5
# Conflicts: # wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java
2 parents 651440b + 1788e1b commit a4dab0e

10 files changed

Lines changed: 683 additions & 603 deletions

File tree

wayang-api/wayang-api-sql/pom.xml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,26 @@
2626
</parent>
2727
<modelVersion>4.0.0</modelVersion>
2828

29+
<properties>
30+
<calcite.version>1.39.0</calcite.version>
31+
</properties>
32+
2933
<artifactId>wayang-api-sql</artifactId>
3034
<dependencies>
3135
<dependency>
3236
<groupId>org.apache.calcite</groupId>
3337
<artifactId>calcite-core</artifactId>
34-
<version>1.32.0</version>
38+
<version>${calcite.version}</version>
3539
</dependency>
3640
<dependency>
3741
<groupId>org.apache.calcite</groupId>
3842
<artifactId>calcite-linq4j</artifactId>
39-
<version>1.32.0</version>
43+
<version>${calcite.version}</version>
4044
</dependency>
4145
<dependency>
4246
<groupId>org.apache.calcite</groupId>
4347
<artifactId>calcite-file</artifactId>
44-
<version>1.29.0</version>
48+
<version>${calcite.version}</version>
4549
</dependency>
4650
<dependency>
4751
<groupId>org.apache.wayang</groupId>

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter;
2020

21+
import java.util.HashSet;
2122
import java.util.List;
2223
import java.util.Set;
2324

@@ -51,7 +52,7 @@ Operator visit(final WayangAggregate wayangRelNode) {
5152

5253
final List<AggregateCall> aggregateCalls = ((Aggregate) wayangRelNode).getAggCallList();
5354
final int groupCount = wayangRelNode.getGroupCount();
54-
final Set<Integer> groupingFields = wayangRelNode.getGroupSet().asSet();
55+
final HashSet<Integer> groupingFields = new HashSet<>(wayangRelNode.getGroupSet().asSet());
5556

5657
final MapOperator<Record, Record> mapOperator = new MapOperator<>(
5758
new AggregateAddCols(aggregateCalls),

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateAddCols.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,42 @@
1818
package org.apache.wayang.api.sql.calcite.converter.functions;
1919

2020
import java.util.List;
21+
import java.util.stream.Collectors;
2122

2223
import org.apache.calcite.rel.core.AggregateCall;
23-
24+
import org.apache.calcite.sql.SqlKind;
2425
import org.apache.wayang.basic.data.Record;
26+
import org.apache.wayang.basic.data.Tuple2;
2527
import org.apache.wayang.core.function.FunctionDescriptor;
2628

2729
public class AggregateAddCols implements FunctionDescriptor.SerializableFunction<Record, Record> {
28-
final List<AggregateCall> aggregateCalls;
29-
30-
public AggregateAddCols(final List<AggregateCall> aggregateCalls){
31-
this.aggregateCalls = aggregateCalls;
30+
final List<Tuple2<SqlKind, List<Integer>>> aggregateCalls;
31+
32+
public AggregateAddCols(final List<AggregateCall> aggregateCalls) {
33+
this.aggregateCalls = aggregateCalls.stream().map(call -> new Tuple2<>(call.getAggregation().getKind(), call.getArgList())).collect(Collectors.toList());
3234
}
3335

3436
@Override
3537
public Record apply(final Record record) {
3638
final int l = record.size();
3739
final int newRecordSize = l + aggregateCalls.size() + 1;
3840
final Object[] resValues = new Object[newRecordSize];
39-
41+
4042
for (int i = 0; i < l; i++) {
4143
resValues[i] = record.getField(i);
4244
}
4345

4446
int i = l;
45-
for (final AggregateCall aggregateCall : aggregateCalls) {
46-
switch (aggregateCall.getAggregation().kind) {
47+
for (final Tuple2<SqlKind, List<Integer>> aggregateCall : aggregateCalls) {
48+
final SqlKind kind = aggregateCall.field0;
49+
final List<Integer> argList = aggregateCall.field1;
50+
51+
switch (kind) {
4752
case COUNT:
4853
resValues[i] = 1;
4954
break;
5055
default:
51-
resValues[i] = record.getField(aggregateCall.getArgList().get(0));
56+
resValues[i] = record.getField(argList.get(0));
5257
}
5358
i++;
5459
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateFunction.java

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,40 +17,44 @@
1717
*/
1818
package org.apache.wayang.api.sql.calcite.converter.functions;
1919

20+
import java.util.Arrays;
2021
import java.util.List;
2122
import java.util.Optional;
2223
import java.util.function.BiFunction;
24+
import java.util.stream.Collectors;
2325

2426
import org.apache.calcite.rel.core.AggregateCall;
2527
import org.apache.calcite.runtime.SqlFunctions;
28+
import org.apache.calcite.sql.SqlKind;
2629
import org.apache.wayang.basic.data.Record;
2730
import org.apache.wayang.core.function.FunctionDescriptor;
2831

2932
public class AggregateFunction
3033
implements FunctionDescriptor.SerializableBinaryOperator<Record> {
31-
32-
final List<AggregateCall> aggregateCalls;
34+
final List<SqlKind> aggregateKinds;
3335

3436
public AggregateFunction(final List<AggregateCall> aggregateCalls) {
35-
this.aggregateCalls = aggregateCalls;
37+
this.aggregateKinds = aggregateCalls.stream()
38+
.map(call -> call.getAggregation().getKind())
39+
.collect(Collectors.toList());
3640
}
3741

3842
@Override
3943
public Record apply(final Record record1, final Record record2) {
4044
final int l = record1.size();
4145
final Object[] resValues = new Object[l];
42-
final boolean countDone = false;
46+
boolean countDone = false;
4347

44-
for (int i = 0; i < l - aggregateCalls.size() - 1; i++) {
48+
for (int i = 0; i < l - aggregateKinds.size() - 1; i++) {
4549
resValues[i] = record1.getField(i);
4650
}
4751

48-
int counter = l - aggregateCalls.size() - 1;
49-
for (final AggregateCall aggregateCall : aggregateCalls) {
52+
int counter = l - aggregateKinds.size() - 1;
53+
for (final SqlKind kind : aggregateKinds) {
5054
final Object field1 = record1.getField(counter);
5155
final Object field2 = record2.getField(counter);
5256

53-
switch (aggregateCall.getAggregation().kind) {
57+
switch (kind) {
5458
case SUM:
5559
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum);
5660
break;
@@ -61,25 +65,33 @@ public Record apply(final Record record1, final Record record2) {
6165
case MAX:
6266
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::greatest, SqlFunctions::greatest,
6367
SqlFunctions::greatest, SqlFunctions::greatest);
68+
break;
6469
case COUNT:
6570
// since aggregates inject an extra column for counting before,
6671
// see AggregateAddCols. the column we operate on are integer counts,
6772
// which means we can eagerly get the fields as integers and simply sum
6873
assert (field1 instanceof Integer && field2 instanceof Integer)
6974
: "Expected to find integers for count but found: " + field1 + " and " + field2;
70-
Object obj = Integer.class.cast(field1) + Integer.class.cast(field2);
71-
resValues[counter] = obj;
75+
final Object count = Integer.class.cast(field1) + Integer.class.cast(field2);
76+
resValues[counter] = count;
7277
break;
7378
case AVG:
74-
throw new UnsupportedOperationException("Averages not currently supported");
75-
// resValues[counter] = this.castAndMap(field1, field2, null, null, null, null);
76-
// break;
79+
assert (field1 instanceof Integer && field2 instanceof Integer)
80+
: "Expected to find integers for count but found: " + field1 + " and " + field2;
81+
final Object avg = Integer.class.cast(field1) + Integer.class.cast(field2);
82+
83+
resValues[counter] = avg;
84+
85+
if (!countDone) {
86+
resValues[l - 1] = record1.getInt(l - 1) + record2.getInt(l - 1);
87+
countDone = true;
88+
}
89+
break;
7790
default:
78-
throw new IllegalStateException("Unsupported operation: " + aggregateCall.getAggregation().kind);
91+
throw new IllegalStateException("Unsupported operation: " + kind);
7992
}
8093
counter++;
8194
}
82-
8395
return new Record(resValues);
8496
}
8597

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateGetResult.java

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,49 +18,48 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter.functions;
2020

21+
import java.util.Arrays;
2122
import java.util.List;
2223
import java.util.Set;
24+
import java.util.stream.Collectors;
25+
import java.util.stream.IntStream;
26+
import java.util.stream.Stream;
2327

2428
import org.apache.calcite.rel.core.AggregateCall;
29+
import org.apache.calcite.sql.SqlKind;
2530
import org.apache.wayang.basic.data.Record;
2631
import org.apache.wayang.core.function.FunctionDescriptor;
2732

2833
public class AggregateGetResult implements FunctionDescriptor.SerializableFunction<Record, Record> {
29-
private final List<AggregateCall> aggregateCallList;
30-
private final Set<Integer> groupingfields;
34+
private final List<SqlKind> aggregateKindList;
35+
private final Set<Integer> groupingfields;
3136

32-
public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<Integer> groupingfields) {
33-
this.aggregateCallList = aggregateCalls;
34-
this.groupingfields = groupingfields;
35-
}
37+
public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<Integer> groupingfields) {
38+
this.aggregateKindList = aggregateCalls.stream()
39+
.map(call -> call.getAggregation().getKind())
40+
.collect(Collectors.toList());
41+
this.groupingfields = groupingfields;
42+
}
3643

37-
@Override
38-
public Record apply(final Record record) {
39-
final int l = record.size();
40-
final int outputRecordSize = aggregateCallList.size() + groupingfields.size();
41-
final Object[] resValues = new Object[outputRecordSize];
44+
@Override
45+
public Record apply(final Record record) {
46+
final int recordSize = record.size();
47+
final int aggregateCallOffset = recordSize - aggregateKindList.size() - 1;
4248

43-
int i = 0;
44-
int j = 0;
45-
for (i = 0; j < groupingfields.size(); i++) {
46-
if (groupingfields.contains(i)) {
47-
resValues[j] = record.getField(i);
48-
j++;
49-
}
50-
}
49+
final Object[] fields = groupingfields.stream()
50+
.map(record::getField)
51+
.toArray();
5152

52-
i = l - aggregateCallList.size() - 1;
53-
for (final AggregateCall aggregateCall : aggregateCallList) {
54-
final String name = aggregateCall.getAggregation().getName();
55-
if (name.equals("AVG")) {
56-
resValues[j] = record.getDouble(i) / record.getDouble(l - 1);
57-
} else {
58-
resValues[j] = record.getField(i);
59-
}
60-
j++;
61-
i++;
62-
}
53+
final Object[] aggregateCallFields = IntStream.range(0, aggregateKindList.size())
54+
.mapToObj(i -> aggregateKindList.get(i).equals(SqlKind.AVG)
55+
? record.getDouble(i + aggregateCallOffset)
56+
/ record.getDouble(recordSize - 1)
57+
: record.getField(i + aggregateCallOffset))
58+
.toArray();
6359

64-
return new Record(resValues);
65-
}
60+
final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
61+
.toArray();
62+
63+
return new Record(combinedFields);
64+
}
6665
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateKeyExtractor.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,16 @@
1919
package org.apache.wayang.api.sql.calcite.converter.functions;
2020

2121
import java.util.ArrayList;
22+
import java.util.HashSet;
2223
import java.util.List;
23-
import java.util.Set;
2424

2525
import org.apache.wayang.basic.data.Record;
2626
import org.apache.wayang.core.function.FunctionDescriptor;
2727

2828
public class AggregateKeyExtractor implements FunctionDescriptor.SerializableFunction<Record, Object> {
29-
private final Set<Integer> indexSet;
29+
private final HashSet<Integer> indexSet;
3030

31-
public AggregateKeyExtractor(final Set<Integer> indexSet) {
31+
public AggregateKeyExtractor(final HashSet<Integer> indexSet) {
3232
this.indexSet = indexSet;
3333
}
3434

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/FilterEvaluateCondition.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public boolean eval(final Record record, final SqlKind kind, final RexNode leftO
7878

7979
switch (kind) {
8080
case LIKE:
81-
return SqlFunctions.like(field.toString(), rexLiteral.toString().replace("'", ""));
81+
return like(field.toString(), rexLiteral.toString().replace("'", ""));
8282
case GREATER_THAN:
8383
return isGreaterThan(field, rexLiteral);
8484
case LESS_THAN:
@@ -127,6 +127,13 @@ public boolean eval(final Record record, final SqlKind kind, final RexNode leftO
127127
}
128128
}
129129

130+
private boolean like(final String s1, final String s2) {
131+
final SqlFunctions.LikeFunction likeFunction = new SqlFunctions.LikeFunction();
132+
final boolean isMatch = likeFunction.like(s1, s2);
133+
134+
return isMatch;
135+
}
136+
130137
private boolean isGreaterThan(final Object o, final RexLiteral rexLiteral) {
131138
// return rexLiteral.getValue().compareTo(o)< 0;
132139
return ((Comparable) o).compareTo(rexLiteral.getValueAs(o.getClass())) > 0;

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/SortFilter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ public SortFilter(final int fetch, final int offset) {
3939

4040
@Override
4141
public boolean test(final Record record) {
42-
final boolean test = increment >= offset && increment <= fetch;
4342
increment++;
44-
43+
final boolean test = increment >= offset && increment <= fetch;
44+
4545
return test;
4646
}
4747
}

0 commit comments

Comments
 (0)