Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,110 +18,144 @@

package org.apache.wayang.api.sql.calcite.converter.functions;

import java.util.ArrayList;
import java.io.Serializable;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.SqlKind;

import org.apache.wayang.core.function.FunctionDescriptor;
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
import org.apache.wayang.core.function.FunctionDescriptor.SerializableBiFunction;
import org.apache.wayang.basic.data.Record;

public class ProjectMapFuncImpl implements
FunctionDescriptor.SerializableFunction<Record, Record> {
final List<SerializableFunction<Record, Object>> projections;

public ProjectMapFuncImpl(final List<RexNode> projects) {
this.projections = projects.stream().map(exp -> {
if (exp instanceof RexInputRef) {
final int key = ((RexInputRef) exp).getIndex();
return (SerializableFunction<Record, Object>) record -> record.getField(key);
} else if (exp instanceof RexLiteral) {
final Object literalValue = ((RexLiteral) exp).getValue();
return (SerializableFunction<Record, Object>) record -> literalValue;
} else if (exp instanceof RexCall) {
return (SerializableFunction<Record, Object>) record -> evaluateRexCall(record, (RexCall) exp);
/**
* Serializable representation of {@link RexNode}
*/
abstract class Node implements Serializable {
public Node transform(final RexNode node) {
if (node instanceof RexCall) {
return new Call((RexCall) node);
} else if (node instanceof RexInputRef) {
return new InputRef((RexInputRef) node);
} else if (node instanceof RexLiteral) {
return new Literal((RexLiteral) node);
} else {
throw new UnsupportedOperationException("Could not resolve record for exp: " + exp);
throw new UnsupportedOperationException("RexNode not supported in projection function: " + node);
}
}).collect(Collectors.toList());
}

abstract Object evaluate(Record record);
}

public Object evaluateRexCall(final Record record, final RexCall rexCall) {
if (rexCall == null) {
return null;
/**
* Serializable representation of {@link RexCall}
*/
class Call extends Node {
final SerializableBiFunction<Number, Number, Number> operation;
final List<Node> children;

public Call(final RexCall call) {
this.operation = deriveOperation(call.getOperator().getKind());
this.children = call.getOperands().stream()
.map(op -> this.transform(op))
.collect(Collectors.toList());
}

// Get the operator and operands
final SqlOperator operator = rexCall.getOperator();
final List<RexNode> operands = rexCall.getOperands();

if (operator == SqlStdOperatorTable.PLUS) {
// Handle addition
return evaluateNaryOperation(record, operands, (a, b) -> a + b);
} else if (operator == SqlStdOperatorTable.MINUS) {
// Handle subtraction
return evaluateNaryOperation(record, operands, (a, b) -> a - b);
} else if (operator == SqlStdOperatorTable.MULTIPLY) {
// Handle multiplication
return evaluateNaryOperation(record, operands, (a, b) -> a * b);
} else if (operator == SqlStdOperatorTable.DIVIDE) {
// Handle division
return evaluateNaryOperation(record, operands, (a, b) -> a / b);
} else {
return null;
public Object evaluate(final Record record) {
assert (children.size() == 2) : "Project func call should only have two children";
return operation.apply((Number) children.get(0).evaluate(record),
(Number) children.get(1).evaluate(record));
}
}

public Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
final BinaryOperator<Double> operation) {
if (operands.isEmpty()) {
return null;
/**
* Derives the java operator for a given {@link SqlKind}, and turns it into a serializable function
* @param kind {@link SqlKind} from {@link RexCall} SqlOperator
* @return a serializable function of +, -, * or /
* @throws UnsupportedOperationException on unrecognized {@link SqlKind}
*/
static SerializableBiFunction<Number, Number, Number> deriveOperation(final SqlKind kind) {
return (a, b) -> {
final double l = a.doubleValue();
final double r = b.doubleValue();
switch (kind) {
case PLUS:
return l + r;
case MINUS:
return l - r;
case TIMES:
return l * r;
case DIVIDE:
return l / r;
default:
throw new UnsupportedOperationException(
"Operation not supported in projection function RexCall: " + kind);
}
};
}
}

final List<Double> values = new ArrayList<>();
/**
* Serializable representation of {@link RexLiteral}
*/
class Literal extends Node {
final Comparable<?> value;

for (int i = 0; i < operands.size(); i++) {
final Number val = (Number) evaluateRexNode(record, operands.get(i));
if (val == null) {
return null;
}
values.add(val.doubleValue());
Literal(final RexLiteral literal) {
this.value = literal.getValueAs(Double.class);
}

Object result = values.get(0);
// Perform the operation with the remaining operands
for (int i = 1; i < operands.size(); i++) {
result = operation.apply((double) result, values.get(i));
@Override
public Object evaluate(final Record record) {
return value;
}

return result;
}

public Object evaluateRexNode(final Record record, final RexNode rexNode) {
if (rexNode instanceof RexCall) {
// Recursively evaluate a RexCall
return evaluateRexCall(record, (RexCall) rexNode);
} else if (rexNode instanceof RexLiteral) {
// Handle literals (e.g., numbers)
final RexLiteral literal = (RexLiteral) rexNode;
return literal.getValue();
} else if (rexNode instanceof RexInputRef) {
return record.getField(((RexInputRef) rexNode).getIndex());
} else {
return null; // Unsupported or unknown expression
/**
* Serializable representation of {@link InputRef}
*/
class InputRef extends Node {
final int key;

InputRef(final RexInputRef inputRef) {
this.key = inputRef.getIndex();
}

@Override
public Object evaluate(final Record record) {
return record.getField(key);
}
}

/**
* AST of the {@link RexCall} arithmetic, composed into serializable nodes; {@link Call}, {@link InputRef}, {@link Literal}
*/
final List<Node> projectionSyntaxTrees;

public ProjectMapFuncImpl(final List<RexNode> projects) {
this.projectionSyntaxTrees = projects.stream()
.map(projection -> {
if (projection instanceof RexCall) {
return new Call((RexCall) projection);
} else if (projection instanceof RexLiteral) {
return new Literal((RexLiteral) projection);
} else if (projection instanceof RexInputRef) {
return new InputRef((RexInputRef) projection);
} else {
throw new UnsupportedOperationException("RexNode not supported in projection: " + projection);
}
})
.collect(Collectors.toList());
}

@Override
public Record apply(final Record record) {
return new Record(projections.stream().map(func -> func.apply(record)).toArray());
return new Record(projectionSyntaxTrees.stream().map(call -> call.evaluate(record)).toArray());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.externalize.RelWriterImpl;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.type.SqlTypeName;
Expand All @@ -34,6 +37,7 @@

import org.apache.wayang.api.sql.calcite.convention.WayangConvention;
import org.apache.wayang.api.sql.calcite.converter.functions.FilterPredicateImpl;
import org.apache.wayang.api.sql.calcite.converter.functions.ProjectMapFuncImpl;
import org.apache.wayang.api.sql.calcite.optimizer.Optimizer;
import org.apache.wayang.api.sql.calcite.rules.WayangRules;
import org.apache.wayang.api.sql.calcite.schema.SchemaUtils;
Expand Down Expand Up @@ -69,6 +73,7 @@
import java.io.StringWriter;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -529,6 +534,44 @@ public void sparkInnerJoin() throws Exception {
assert (resultTally.equals(shouldBeTally));
}

@Test
public void serializeProjection() throws Exception {
final RexBuilder rb = new RexBuilder(new JavaTypeFactoryImpl());

final RelDataTypeFactory typeFactory = rb.getTypeFactory();
final RelDataType intType = typeFactory.createSqlType(SqlTypeName.INTEGER);
final RelDataType rowType = typeFactory.createStructType(
Arrays.asList(intType, intType, intType),
Arrays.asList("x", "b", "y"));

final RexNode inputRefX = rb.makeInputRef(rowType, 0);
final RexNode inputRefB = rb.makeInputRef(rowType, 1);
final RexNode inputRefY = rb.makeInputRef(rowType, 2);
final SqlOperator add = SqlStdOperatorTable.PLUS;
final SqlOperator multiply = SqlStdOperatorTable.MULTIPLY;

final RexNode addition = rb.makeCall(add, List.of(inputRefX, inputRefB));
final RexNode multiplication = rb.makeCall(multiply, List.of(addition, inputRefY));

final RexCall projection = (RexCall) multiplication;

final ProjectMapFuncImpl impl = new ProjectMapFuncImpl(List.of(projection));

final ByteArrayOutputStream byteOutStream = new ByteArrayOutputStream();
final ObjectOutputStream outStream = new ObjectOutputStream(byteOutStream);
outStream.writeObject(impl);
outStream.close();

final ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray());
final ObjectInputStream inStream = new ObjectInputStream(byteInStream);
final ProjectMapFuncImpl deserializedImpl = (ProjectMapFuncImpl) inStream.readObject();
inStream.close();

final Record testRecord = new Record(1,2,3);

assert (impl.apply(testRecord).equals(deserializedImpl.apply(testRecord)));
}

// @Test
public void rexSerializationTest() throws Exception {
// create filterPredicateImpl for serialisation
Expand Down
Loading