Skip to content

Commit 58cfcbb

Browse files
authored
Merge pull request #564 from mspruc/main
fix serialization for RexCalls in sql-api projections
2 parents 9d71386 + af8aefe commit 58cfcbb

2 files changed

Lines changed: 147 additions & 70 deletions

File tree

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/ProjectMapFuncImpl.java

Lines changed: 104 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,110 +18,144 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter.functions;
2020

21-
import java.util.ArrayList;
21+
import java.io.Serializable;
2222
import java.util.List;
23-
import java.util.function.BinaryOperator;
2423
import java.util.stream.Collectors;
2524

2625
import org.apache.calcite.rex.RexCall;
2726
import org.apache.calcite.rex.RexInputRef;
2827
import org.apache.calcite.rex.RexLiteral;
2928
import org.apache.calcite.rex.RexNode;
30-
import org.apache.calcite.sql.SqlOperator;
31-
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
29+
import org.apache.calcite.sql.SqlKind;
3230

3331
import org.apache.wayang.core.function.FunctionDescriptor;
34-
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
32+
import org.apache.wayang.core.function.FunctionDescriptor.SerializableBiFunction;
3533
import org.apache.wayang.basic.data.Record;
3634

3735
public class ProjectMapFuncImpl implements
3836
FunctionDescriptor.SerializableFunction<Record, Record> {
39-
final List<SerializableFunction<Record, Object>> projections;
4037

41-
public ProjectMapFuncImpl(final List<RexNode> projects) {
42-
this.projections = projects.stream().map(exp -> {
43-
if (exp instanceof RexInputRef) {
44-
final int key = ((RexInputRef) exp).getIndex();
45-
return (SerializableFunction<Record, Object>) record -> record.getField(key);
46-
} else if (exp instanceof RexLiteral) {
47-
final Object literalValue = ((RexLiteral) exp).getValue();
48-
return (SerializableFunction<Record, Object>) record -> literalValue;
49-
} else if (exp instanceof RexCall) {
50-
return (SerializableFunction<Record, Object>) record -> evaluateRexCall(record, (RexCall) exp);
38+
/**
39+
* Serializable representation of {@link RexNode}
40+
*/
41+
abstract class Node implements Serializable {
42+
public Node transform(final RexNode node) {
43+
if (node instanceof RexCall) {
44+
return new Call((RexCall) node);
45+
} else if (node instanceof RexInputRef) {
46+
return new InputRef((RexInputRef) node);
47+
} else if (node instanceof RexLiteral) {
48+
return new Literal((RexLiteral) node);
5149
} else {
52-
throw new UnsupportedOperationException("Could not resolve record for exp: " + exp);
50+
throw new UnsupportedOperationException("RexNode not supported in projection function: " + node);
5351
}
54-
}).collect(Collectors.toList());
52+
}
53+
54+
abstract Object evaluate(Record record);
5555
}
5656

57-
public Object evaluateRexCall(final Record record, final RexCall rexCall) {
58-
if (rexCall == null) {
59-
return null;
57+
/**
58+
* Serializable representation of {@link RexCall}
59+
*/
60+
class Call extends Node {
61+
final SerializableBiFunction<Number, Number, Number> operation;
62+
final List<Node> children;
63+
64+
public Call(final RexCall call) {
65+
this.operation = deriveOperation(call.getOperator().getKind());
66+
this.children = call.getOperands().stream()
67+
.map(op -> this.transform(op))
68+
.collect(Collectors.toList());
6069
}
6170

62-
// Get the operator and operands
63-
final SqlOperator operator = rexCall.getOperator();
64-
final List<RexNode> operands = rexCall.getOperands();
65-
66-
if (operator == SqlStdOperatorTable.PLUS) {
67-
// Handle addition
68-
return evaluateNaryOperation(record, operands, (a, b) -> a + b);
69-
} else if (operator == SqlStdOperatorTable.MINUS) {
70-
// Handle subtraction
71-
return evaluateNaryOperation(record, operands, (a, b) -> a - b);
72-
} else if (operator == SqlStdOperatorTable.MULTIPLY) {
73-
// Handle multiplication
74-
return evaluateNaryOperation(record, operands, (a, b) -> a * b);
75-
} else if (operator == SqlStdOperatorTable.DIVIDE) {
76-
// Handle division
77-
return evaluateNaryOperation(record, operands, (a, b) -> a / b);
78-
} else {
79-
return null;
71+
public Object evaluate(final Record record) {
72+
assert (children.size() == 2) : "Project func call should only have two children";
73+
return operation.apply((Number) children.get(0).evaluate(record),
74+
(Number) children.get(1).evaluate(record));
8075
}
81-
}
8276

83-
public Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
84-
final BinaryOperator<Double> operation) {
85-
if (operands.isEmpty()) {
86-
return null;
77+
/**
78+
* Derives the java operator for a given {@link SqlKind}, and turns it into a serializable function
79+
* @param kind {@link SqlKind} from {@link RexCall} SqlOperator
80+
* @return a serializable function of +, -, * or /
81+
* @throws UnsupportedOperationException on unrecognized {@link SqlKind}
82+
*/
83+
static SerializableBiFunction<Number, Number, Number> deriveOperation(final SqlKind kind) {
84+
return (a, b) -> {
85+
final double l = a.doubleValue();
86+
final double r = b.doubleValue();
87+
switch (kind) {
88+
case PLUS:
89+
return l + r;
90+
case MINUS:
91+
return l - r;
92+
case TIMES:
93+
return l * r;
94+
case DIVIDE:
95+
return l / r;
96+
default:
97+
throw new UnsupportedOperationException(
98+
"Operation not supported in projection function RexCall: " + kind);
99+
}
100+
};
87101
}
102+
}
88103

89-
final List<Double> values = new ArrayList<>();
104+
/**
105+
* Serializable representation of {@link RexLiteral}
106+
*/
107+
class Literal extends Node {
108+
final Comparable<?> value;
90109

91-
for (int i = 0; i < operands.size(); i++) {
92-
final Number val = (Number) evaluateRexNode(record, operands.get(i));
93-
if (val == null) {
94-
return null;
95-
}
96-
values.add(val.doubleValue());
110+
Literal(final RexLiteral literal) {
111+
this.value = literal.getValueAs(Double.class);
97112
}
98113

99-
Object result = values.get(0);
100-
// Perform the operation with the remaining operands
101-
for (int i = 1; i < operands.size(); i++) {
102-
result = operation.apply((double) result, values.get(i));
114+
@Override
115+
public Object evaluate(final Record record) {
116+
return value;
103117
}
104-
105-
return result;
106118
}
107119

108-
public Object evaluateRexNode(final Record record, final RexNode rexNode) {
109-
if (rexNode instanceof RexCall) {
110-
// Recursively evaluate a RexCall
111-
return evaluateRexCall(record, (RexCall) rexNode);
112-
} else if (rexNode instanceof RexLiteral) {
113-
// Handle literals (e.g., numbers)
114-
final RexLiteral literal = (RexLiteral) rexNode;
115-
return literal.getValue();
116-
} else if (rexNode instanceof RexInputRef) {
117-
return record.getField(((RexInputRef) rexNode).getIndex());
118-
} else {
119-
return null; // Unsupported or unknown expression
120+
/**
121+
* Serializable representation of {@link InputRef}
122+
*/
123+
class InputRef extends Node {
124+
final int key;
125+
126+
InputRef(final RexInputRef inputRef) {
127+
this.key = inputRef.getIndex();
128+
}
129+
130+
@Override
131+
public Object evaluate(final Record record) {
132+
return record.getField(key);
120133
}
121134
}
122135

136+
/**
137+
* AST of the {@link RexCall} arithmetic, composed into serializable nodes; {@link Call}, {@link InputRef}, {@link Literal}
138+
*/
139+
final List<Node> projectionSyntaxTrees;
140+
141+
public ProjectMapFuncImpl(final List<RexNode> projects) {
142+
this.projectionSyntaxTrees = projects.stream()
143+
.map(projection -> {
144+
if (projection instanceof RexCall) {
145+
return new Call((RexCall) projection);
146+
} else if (projection instanceof RexLiteral) {
147+
return new Literal((RexLiteral) projection);
148+
} else if (projection instanceof RexInputRef) {
149+
return new InputRef((RexInputRef) projection);
150+
} else {
151+
throw new UnsupportedOperationException("RexNode not supported in projection: " + projection);
152+
}
153+
})
154+
.collect(Collectors.toList());
155+
}
156+
123157
@Override
124158
public Record apply(final Record record) {
125-
return new Record(projections.stream().map(func -> func.apply(record)).toArray());
159+
return new Record(projectionSyntaxTrees.stream().map(call -> call.evaluate(record)).toArray());
126160
}
127161
}

wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
import org.apache.calcite.rel.RelNode;
2222
import org.apache.calcite.rel.externalize.RelWriterImpl;
2323
import org.apache.calcite.rel.rules.CoreRules;
24+
import org.apache.calcite.rel.type.RelDataType;
2425
import org.apache.calcite.rel.type.RelDataTypeFactory;
2526
import org.apache.calcite.rex.RexBuilder;
27+
import org.apache.calcite.rex.RexCall;
2628
import org.apache.calcite.rex.RexNode;
2729
import org.apache.calcite.sql.SqlExplainLevel;
2830
import org.apache.calcite.sql.SqlNode;
31+
import org.apache.calcite.sql.SqlOperator;
2932
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
3033
import org.apache.calcite.sql.parser.SqlParseException;
3134
import org.apache.calcite.sql.type.SqlTypeName;
@@ -34,6 +37,7 @@
3437

3538
import org.apache.wayang.api.sql.calcite.convention.WayangConvention;
3639
import org.apache.wayang.api.sql.calcite.converter.functions.FilterPredicateImpl;
40+
import org.apache.wayang.api.sql.calcite.converter.functions.ProjectMapFuncImpl;
3741
import org.apache.wayang.api.sql.calcite.optimizer.Optimizer;
3842
import org.apache.wayang.api.sql.calcite.rules.WayangRules;
3943
import org.apache.wayang.api.sql.calcite.schema.SchemaUtils;
@@ -69,6 +73,7 @@
6973
import java.io.StringWriter;
7074
import java.sql.SQLException;
7175
import java.util.ArrayList;
76+
import java.util.Arrays;
7277
import java.util.Collection;
7378
import java.util.List;
7479
import java.util.Map;
@@ -529,6 +534,44 @@ public void sparkInnerJoin() throws Exception {
529534
assert (resultTally.equals(shouldBeTally));
530535
}
531536

537+
@Test
538+
public void serializeProjection() throws Exception {
539+
final RexBuilder rb = new RexBuilder(new JavaTypeFactoryImpl());
540+
541+
final RelDataTypeFactory typeFactory = rb.getTypeFactory();
542+
final RelDataType intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
543+
final RelDataType rowType = typeFactory.createStructType(
544+
Arrays.asList(intType, intType, intType),
545+
Arrays.asList("x", "b", "y"));
546+
547+
final RexNode inputRefX = rb.makeInputRef(rowType, 0);
548+
final RexNode inputRefB = rb.makeInputRef(rowType, 1);
549+
final RexNode inputRefY = rb.makeInputRef(rowType, 2);
550+
final SqlOperator add = SqlStdOperatorTable.PLUS;
551+
final SqlOperator multiply = SqlStdOperatorTable.MULTIPLY;
552+
553+
final RexNode addition = rb.makeCall(add, List.of(inputRefX, inputRefB));
554+
final RexNode multiplication = rb.makeCall(multiply, List.of(addition, inputRefY));
555+
556+
final RexCall projection = (RexCall) multiplication;
557+
558+
final ProjectMapFuncImpl impl = new ProjectMapFuncImpl(List.of(projection));
559+
560+
final ByteArrayOutputStream byteOutStream = new ByteArrayOutputStream();
561+
final ObjectOutputStream outStream = new ObjectOutputStream(byteOutStream);
562+
outStream.writeObject(impl);
563+
outStream.close();
564+
565+
final ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray());
566+
final ObjectInputStream inStream = new ObjectInputStream(byteInStream);
567+
final ProjectMapFuncImpl deserializedImpl = (ProjectMapFuncImpl) inStream.readObject();
568+
inStream.close();
569+
570+
final Record testRecord = new Record(1,2,3);
571+
572+
assert (impl.apply(testRecord).equals(deserializedImpl.apply(testRecord)));
573+
}
574+
532575
// @Test
533576
public void rexSerializationTest() throws Exception {
534577
// create filterPredicateImpl for serialisation

0 commit comments

Comments
 (0)