Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
Expand All @@ -30,36 +31,30 @@
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

import org.apache.wayang.core.function.FunctionDescriptor;
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
import org.apache.wayang.basic.data.Record;


public class ProjectMapFuncImpl implements
FunctionDescriptor.SerializableFunction<Record, Record> {
private final List<RexNode> projects;
final List<SerializableFunction<Record, Object>> projections;

public ProjectMapFuncImpl(final List<RexNode> projects) {
this.projects = projects;
}

@Override
public Record apply(final Record record) {

final List<Object> projectedRecord = new ArrayList<>();
for (int i = 0; i < projects.size(); i++) {
final RexNode exp = projects.get(i);
this.projections = projects.stream().map(exp -> {
if (exp instanceof RexInputRef) {
projectedRecord.add(record.getField(((RexInputRef) exp).getIndex()));
final int key = ((RexInputRef) exp).getIndex();
return (SerializableFunction<Record, Object>) record -> record.getField(key);
} else if (exp instanceof RexLiteral) {
final RexLiteral literal = (RexLiteral) exp;
projectedRecord.add(literal.getValue());
final Object literalValue = ((RexLiteral) exp).getValue();
return (SerializableFunction<Record, Object>) record -> literalValue;
} else if (exp instanceof RexCall) {
projectedRecord.add(evaluateRexCall(record, (RexCall) exp));
return (SerializableFunction<Record, Object>) record -> evaluateRexCall(record, (RexCall) exp);
} else {
throw new UnsupportedOperationException("Could not resolve record for exp: " + exp);
}
}
return new Record(projectedRecord.toArray(new Object[0]));
}).collect(Collectors.toList());
}

public static Object evaluateRexCall(final Record record, final RexCall rexCall) {
public Object evaluateRexCall(final Record record, final RexCall rexCall) {
if (rexCall == null) {
return null;
}
Expand All @@ -70,7 +65,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)

if (operator == SqlStdOperatorTable.PLUS) {
// Handle addition
return evaluateNaryOperation(record, operands, Double::sum);
return evaluateNaryOperation(record, operands, (a, b) -> a + b);
} else if (operator == SqlStdOperatorTable.MINUS) {
// Handle subtraction
return evaluateNaryOperation(record, operands, (a, b) -> a - b);
Expand All @@ -85,7 +80,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
}
}

public static Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
public Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
final BinaryOperator<Double> operation) {
if (operands.isEmpty()) {
return null;
Expand All @@ -110,7 +105,7 @@ public static Object evaluateNaryOperation(final Record record, final List<RexNo
return result;
}

public static Object evaluateRexNode(final Record record, final RexNode rexNode) {
public Object evaluateRexNode(final Record record, final RexNode rexNode) {
if (rexNode instanceof RexCall) {
// Recursively evaluate a RexCall
return evaluateRexCall(record, (RexCall) rexNode);
Expand All @@ -124,4 +119,9 @@ public static Object evaluateRexNode(final Record record, final RexNode rexNode)
return null; // Unsupported or unknown expression
}
}

@Override
public Record apply(final Record record) {
return new Record(projections.stream().map(func -> func.apply(record)).toArray());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,24 @@

package org.apache.wayang.api.sql;

import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.externalize.RelWriterImpl;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RuleSet;
import org.apache.calcite.tools.RuleSets;

import org.apache.wayang.api.sql.calcite.convention.WayangConvention;
import org.apache.wayang.api.sql.calcite.converter.functions.FilterPredicateImpl;
import org.apache.wayang.api.sql.calcite.optimizer.Optimizer;
import org.apache.wayang.api.sql.calcite.rules.WayangRules;
import org.apache.wayang.api.sql.calcite.schema.SchemaUtils;
Expand All @@ -41,18 +46,25 @@
import org.apache.wayang.api.sql.context.SqlContext;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
import org.apache.wayang.core.function.FunctionDescriptor.SerializablePredicate;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
import org.apache.wayang.core.plan.wayangplan.WayangPlan;
import org.apache.wayang.java.Java;
import org.apache.wayang.spark.Spark;
import org.apache.wayang.basic.data.Record;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;

import org.junit.Test;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.sql.SQLException;
Expand Down Expand Up @@ -371,6 +383,62 @@ public void joinWithLargeLeftTableIndexMirrorAlias() throws Exception {
assert (resultTally.equals(shouldBeTally));
}

// tests sql-apis ability to serialize projections and joins
@Test
public void sparkInnerJoin() throws Exception {
final SqlContext sqlContext = createSqlContext("/data/largeLeftTableIndex.csv");

final Tuple2<Collection<Record>, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext,
"SELECT * FROM fs.largeLeftTableIndex AS na INNER JOIN fs.largeLeftTableIndex AS nb ON nb.NAMEB = na.NAMEA " //
);

final Collection<Record> result = t.field0;
final WayangPlan wayangPlan = t.field1;

PlanTraversal.upstream().traverse(wayangPlan.getSinks()).getTraversedNodes().forEach(node -> {
node.addTargetPlatform(Spark.platform());
});

sqlContext.execute(wayangPlan);

final List<Record> shouldBe = List.of(
new Record("test1", "test1", "test2", "test1", "test1", "test2"),
new Record("test2", "" , "test2", "" , "test2", "test2"),
new Record("" , "test2", "test2", "test2", "" , "test2")
);

final Map<Record, Integer> resultTally = result.stream()
.collect(Collectors.toMap(rec -> rec, rec -> 1, Integer::sum));
final Map<Record, Integer> shouldBeTally = shouldBe.stream()
.collect(Collectors.toMap(rec -> rec, rec -> 1, Integer::sum));

assert (resultTally.equals(shouldBeTally));
}

//@Test
public void rexSerializationTest() throws Exception {
// create filterPredicateImpl for serialisation
final RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl();
final RexBuilder rb = new RexBuilder(typeFactory);
final RexNode leftOperand = rb.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0);
final RexNode rightOperand = rb.makeLiteral("test");
final RexNode cond = rb.makeCall(SqlStdOperatorTable.EQUALS, leftOperand, rightOperand);
final SerializablePredicate<?> fpImpl = new FilterPredicateImpl(cond);

final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
final ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
objectOutputStream.writeObject(fpImpl);
objectOutputStream.close();

final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(
byteArrayOutputStream.toByteArray());
final ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
final Object deserializedObject = objectInputStream.readObject();
objectInputStream.close();

assert (((FilterPredicateImpl) deserializedObject).test(new Record("test")));
}

@Test
public void exampleFilterTableRefToTableRef() throws Exception {
final SqlContext sqlContext = createSqlContext("/data/exampleRefToRef.csv");
Expand Down
Loading