Skip to content

Commit ea66604

Browse files
committed
make project functions serializable in the sql-api
1 parent 2c4f8ca commit ea66604

2 files changed

Lines changed: 24 additions & 23 deletions

File tree

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/ProjectMapFuncImpl.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.ArrayList;
2222
import java.util.List;
2323
import java.util.function.BinaryOperator;
24+
import java.util.stream.Collectors;
2425

2526
import org.apache.calcite.rex.RexCall;
2627
import org.apache.calcite.rex.RexInputRef;
@@ -30,36 +31,30 @@
3031
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
3132

3233
import org.apache.wayang.core.function.FunctionDescriptor;
34+
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
3335
import org.apache.wayang.basic.data.Record;
3436

35-
3637
public class ProjectMapFuncImpl implements
3738
FunctionDescriptor.SerializableFunction<Record, Record> {
38-
private final List<RexNode> projects;
39+
final List<SerializableFunction<Record, Object>> projections;
3940

4041
public ProjectMapFuncImpl(final List<RexNode> projects) {
41-
this.projects = projects;
42-
}
43-
44-
@Override
45-
public Record apply(final Record record) {
46-
47-
final List<Object> projectedRecord = new ArrayList<>();
48-
for (int i = 0; i < projects.size(); i++) {
49-
final RexNode exp = projects.get(i);
42+
this.projections = projects.stream().map(exp -> {
5043
if (exp instanceof RexInputRef) {
51-
projectedRecord.add(record.getField(((RexInputRef) exp).getIndex()));
44+
final int key = ((RexInputRef) exp).getIndex();
45+
return (SerializableFunction<Record, Object>) record -> record.getField(key);
5246
} else if (exp instanceof RexLiteral) {
53-
final RexLiteral literal = (RexLiteral) exp;
54-
projectedRecord.add(literal.getValue());
47+
final Object literalValue = ((RexLiteral) exp).getValue();
48+
return (SerializableFunction<Record, Object>) record -> literalValue;
5549
} else if (exp instanceof RexCall) {
56-
projectedRecord.add(evaluateRexCall(record, (RexCall) exp));
50+
return (SerializableFunction<Record, Object>) record -> evaluateRexCall(record, (RexCall) exp);
51+
} else {
52+
throw new UnsupportedOperationException("Could not resolve record for exp: " + exp);
5753
}
58-
}
59-
return new Record(projectedRecord.toArray(new Object[0]));
54+
}).collect(Collectors.toList());
6055
}
6156

62-
public static Object evaluateRexCall(final Record record, final RexCall rexCall) {
57+
public Object evaluateRexCall(final Record record, final RexCall rexCall) {
6358
if (rexCall == null) {
6459
return null;
6560
}
@@ -70,7 +65,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
7065

7166
if (operator == SqlStdOperatorTable.PLUS) {
7267
// Handle addition
73-
return evaluateNaryOperation(record, operands, Double::sum);
68+
return evaluateNaryOperation(record, operands, (a, b) -> a + b);
7469
} else if (operator == SqlStdOperatorTable.MINUS) {
7570
// Handle subtraction
7671
return evaluateNaryOperation(record, operands, (a, b) -> a - b);
@@ -85,7 +80,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
8580
}
8681
}
8782

88-
public static Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
83+
public Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
8984
final BinaryOperator<Double> operation) {
9085
if (operands.isEmpty()) {
9186
return null;
@@ -110,7 +105,7 @@ public static Object evaluateNaryOperation(final Record record, final List<RexNo
110105
return result;
111106
}
112107

113-
public static Object evaluateRexNode(final Record record, final RexNode rexNode) {
108+
public Object evaluateRexNode(final Record record, final RexNode rexNode) {
114109
if (rexNode instanceof RexCall) {
115110
// Recursively evaluate a RexCall
116111
return evaluateRexCall(record, (RexCall) rexNode);
@@ -124,4 +119,9 @@ public static Object evaluateRexNode(final Record record, final RexNode rexNode)
124119
return null; // Unsupported or unknown expression
125120
}
126121
}
122+
123+
@Override
124+
public Record apply(final Record record) {
125+
return new Record(projections.stream().map(func -> func.apply(record)).toArray());
126+
}
127127
}

wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,9 @@ public void joinWithLargeLeftTableIndexMirrorAlias() throws Exception {
383383
assert (resultTally.equals(shouldBeTally));
384384
}
385385

386-
//@Test
387-
public void flinkInnerJoin() throws Exception {
386+
// tests sql-apis ability to serialize projections and joins
387+
@Test
388+
public void sparkInnerJoin() throws Exception {
388389
final SqlContext sqlContext = createSqlContext("/data/largeLeftTableIndex.csv");
389390

390391
final Tuple2<Collection<Record>, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext,

0 commit comments

Comments
 (0)