Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.wayang.api.sql.calcite.converter;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -51,7 +52,7 @@ Operator visit(final WayangAggregate wayangRelNode) {

final List<AggregateCall> aggregateCalls = ((Aggregate) wayangRelNode).getAggCallList();
final int groupCount = wayangRelNode.getGroupCount();
final Set<Integer> groupingFields = wayangRelNode.getGroupSet().asSet();
final HashSet<Integer> groupingFields = new HashSet<>(wayangRelNode.getGroupSet().asSet());

final MapOperator<Record, Record> mapOperator = new MapOperator<>(
new AggregateAddCols(aggregateCalls),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@
package org.apache.wayang.api.sql.calcite.converter.functions;

import java.util.List;
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.AggregateCall;

import org.apache.calcite.sql.SqlKind;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.core.function.FunctionDescriptor;

public class AggregateAddCols implements FunctionDescriptor.SerializableFunction<Record, Record> {
final List<AggregateCall> aggregateCalls;
final List<Tuple2<SqlKind, List<Integer>>> aggregateCalls;

public AggregateAddCols(final List<AggregateCall> aggregateCalls) {
this.aggregateCalls = aggregateCalls;
this.aggregateCalls = aggregateCalls.stream().map(call -> new Tuple2<>(call.getAggregation().getKind(), call.getArgList())).collect(Collectors.toList());
}

@Override
Expand All @@ -42,13 +44,16 @@ public Record apply(final Record record) {
}

int i = l;
for (final AggregateCall aggregateCall : aggregateCalls) {
switch (aggregateCall.getAggregation().kind) {
for (final Tuple2<SqlKind, List<Integer>> aggregateCall : aggregateCalls) {
final SqlKind kind = aggregateCall.field0;
final List<Integer> argList = aggregateCall.field1;

switch (kind) {
case COUNT:
resValues[i] = 1;
break;
default:
resValues[i] = record.getField(aggregateCall.getArgList().get(0));
resValues[i] = record.getField(argList.get(0));
}
i++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,22 @@
import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.sql.SqlKind;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.core.function.FunctionDescriptor;

public class AggregateFunction
implements FunctionDescriptor.SerializableBinaryOperator<Record> {

final List<AggregateCall> aggregateCalls;
final List<SqlKind> aggregateKinds;

public AggregateFunction(final List<AggregateCall> aggregateCalls) {
this.aggregateCalls = aggregateCalls;
this.aggregateKinds = aggregateCalls.stream()
.map(call -> call.getAggregation().getKind())
.collect(Collectors.toList());
}

@Override
Expand All @@ -42,16 +45,16 @@ public Record apply(final Record record1, final Record record2) {
final Object[] resValues = new Object[l];
boolean countDone = false;

for (int i = 0; i < l - aggregateCalls.size() - 1; i++) {
for (int i = 0; i < l - aggregateKinds.size() - 1; i++) {
resValues[i] = record1.getField(i);
}

int counter = l - aggregateCalls.size() - 1;
for (final AggregateCall aggregateCall : aggregateCalls) {
int counter = l - aggregateKinds.size() - 1;
for (final SqlKind kind : aggregateKinds) {
final Object field1 = record1.getField(counter);
final Object field2 = record2.getField(counter);

switch (aggregateCall.getAggregation().kind) {
switch (kind) {
case SUM:
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum);
break;
Expand Down Expand Up @@ -85,7 +88,7 @@ public Record apply(final Record record1, final Record record2) {
}
break;
default:
throw new IllegalStateException("Unsupported operation: " + aggregateCall.getAggregation().kind);
throw new IllegalStateException("Unsupported operation: " + kind);
}
counter++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

Expand All @@ -30,32 +31,35 @@
import org.apache.wayang.core.function.FunctionDescriptor;

public class AggregateGetResult implements FunctionDescriptor.SerializableFunction<Record, Record> {
private final List<AggregateCall> aggregateCallList;
private final Set<Integer> groupingfields;

public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<Integer> groupingfields) {
this.aggregateCallList = aggregateCalls;
this.groupingfields = groupingfields;
}

@Override
public Record apply(final Record record) {
final int recordSize = record.size();
final int aggregateCallOffset = recordSize - aggregateCallList.size() - 1;

final Object[] fields = groupingfields.stream()
.map(record::getField)
.toArray();

final Object[] aggregateCallFields = IntStream.range(0, aggregateCallList.size())
.mapToObj(i -> aggregateCallList.get(i).getAggregation().getKind().equals(SqlKind.AVG)
? record.getDouble(i + aggregateCallOffset) / record.getDouble(recordSize - 1)
: record.getField(i + aggregateCallOffset))
.toArray();

final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
.toArray();

return new Record(combinedFields);
}
private final List<SqlKind> aggregateKindList;
private final Set<Integer> groupingfields;

public AggregateGetResult(final List<AggregateCall> aggregateCalls, final Set<Integer> groupingfields) {
this.aggregateKindList = aggregateCalls.stream()
.map(call -> call.getAggregation().getKind())
.collect(Collectors.toList());
this.groupingfields = groupingfields;
}

@Override
public Record apply(final Record record) {
final int recordSize = record.size();
final int aggregateCallOffset = recordSize - aggregateKindList.size() - 1;

final Object[] fields = groupingfields.stream()
.map(record::getField)
.toArray();

final Object[] aggregateCallFields = IntStream.range(0, aggregateKindList.size())
.mapToObj(i -> aggregateKindList.get(i).equals(SqlKind.AVG)
? record.getDouble(i + aggregateCallOffset)
/ record.getDouble(recordSize - 1)
: record.getField(i + aggregateCallOffset))
.toArray();

final Object[] combinedFields = Stream.concat(Arrays.stream(fields), Arrays.stream(aggregateCallFields))
.toArray();

return new Record(combinedFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
package org.apache.wayang.api.sql.calcite.converter.functions;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.wayang.basic.data.Record;
import org.apache.wayang.core.function.FunctionDescriptor;

public class AggregateKeyExtractor implements FunctionDescriptor.SerializableFunction<Record, Object> {
private final Set<Integer> indexSet;
private final HashSet<Integer> indexSet;

public AggregateKeyExtractor(final Set<Integer> indexSet) {
public AggregateKeyExtractor(final HashSet<Integer> indexSet) {
this.indexSet = indexSet;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ public SortFilter(final int fetch, final int offset) {

@Override
public boolean test(final Record record) {
final boolean test = increment >= offset && increment <= fetch;
increment++;

final boolean test = increment >= offset && increment <= fetch;

return test;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,27 @@ public void javaLimit() throws Exception {

final List<Record> result = r.stream().collect(Collectors.toList());

assert (result.size() == 1);
assert (result.get(0).equals(new Record(2, "a", "a", 2)));
}

@Test
public void javaLimitNoSort() throws Exception {
final SqlContext sqlContext = createSqlContext("/data/exampleSort.csv");

final Tuple2<Collection<Record>, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext,
"SELECT col1, col2, col3 from fs.exampleSort LIMIT 2");

final Collection<Record> r = t.field0;
final WayangPlan wayangPlan = t.field1;

sqlContext.execute(wayangPlan);

final List<Record> result = r.stream().collect(Collectors.toList());

assert (result.size() == 2);
}

@Test
public void javaSort() throws Exception {
final SqlContext sqlContext = createSqlContext("/data/exampleSort.csv");
Expand Down Expand Up @@ -460,6 +478,26 @@ public void sparkFilter() throws Exception {
assert (result.stream().anyMatch(rec -> rec.equals(new Record("test1", "test1"))));
}

@Test
public void sparkAggregate() throws Exception {
final SqlContext sqlContext = this.createSqlContext("/data/largeLeftTableIndex.csv");
final Tuple2<Collection<Record>, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext,
"SELECT largeLeftTableIndex.NAMEC, COUNT(*) FROM fs.largeLeftTableIndex GROUP BY NAMEC");
final Collection<Record> result = t.field0;
final WayangPlan wayangPlan = t.field1;

// except reduce by
PlanTraversal.upstream().traverse(wayangPlan.getSinks()).getTraversedNodes().forEach(node -> {
node.addTargetPlatform(Spark.platform());
});

sqlContext.execute(wayangPlan);

final Record rec = result.stream().findFirst().get();
assert (rec.size() == 2);
assert (rec.getInt(1) == 3);
}

// tests sql-apis ability to serialize projections and joins
@Test
public void sparkInnerJoin() throws Exception {
Expand Down
Loading
Loading