diff --git a/wayang-plugins/wayang-ml/.project b/wayang-plugins/wayang-ml/.project
new file mode 100644
index 000000000..9c6ad63ff
--- /dev/null
+++ b/wayang-plugins/wayang-ml/.project
@@ -0,0 +1,34 @@
+
+
+ wayang-ml
+
+
+
+
+
+ org.eclipse.jdt.core.javabuilder
+
+
+
+
+ org.eclipse.m2e.core.maven2Builder
+
+
+
+
+
+ org.eclipse.jdt.core.javanature
+ org.eclipse.m2e.core.maven2Nature
+
+
+
+ 1715926807014
+
+ 30
+
+ org.eclipse.core.resources.regexFilterMatcher
+ node_modules|\.git|__CREATED_BY_JAVA_LANGUAGE_SERVER__
+
+
+
+
diff --git a/wayang-plugins/wayang-ml/pom.xml b/wayang-plugins/wayang-ml/pom.xml
new file mode 100644
index 000000000..6187dd978
--- /dev/null
+++ b/wayang-plugins/wayang-ml/pom.xml
@@ -0,0 +1,157 @@
+
+
+
+
+ 4.0.0
+
+
+ org.apache.wayang
+ wayang-plugins
+ 0.7.1
+
+
+ wayang-ml
+ 0.7.1
+
+
+ org.apache.wayang.extensions.ml
+
+
+
+
+ org.apache.wayang
+ wayang-api-sql
+ 0.7.1
+
+
+ com.microsoft.onnxruntime
+ onnxruntime
+ 1.21.1
+
+
+
+ org.apache.wayang
+ wayang-core
+ 0.7.1
+
+
+ org.apache.wayang
+ wayang-basic
+ 0.7.1
+
+
+ org.apache.wayang
+ wayang-java
+ 0.7.1
+
+
+ org.apache.wayang
+ wayang-spark
+ 0.7.1
+
+
+ org.apache.wayang
+ wayang-flink
+ 0.7.1
+
+
+ org.apache.flink
+ flink-java
+ ${flink.version}
+
+
+ org.apache.wayang
+ wayang-giraph
+ 0.7.1
+
+
+ org.apache.wayang
+ wayang-generic-jdbc
+ 0.7.1
+
+
+ org.reflections
+ reflections
+ 0.10.2
+
+
+ org.apache.wayang
+ wayang-benchmark
+ 0.7.1
+
+
+ org.apache.wayang
+ wayang-api-python
+ 0.7.1
+
+
+ org.apache.commons
+ commons-dbcp2
+ 2.7.0
+
+
+ org.apache.spark
+ spark-core_2.12
+ ${spark.version}
+
+
+ org.apache.spark
+ spark-graphx_2.12
+ ${spark.version}
+
+
+ org.apache.spark
+ spark-mllib_2.12
+ ${spark.version}
+
+
+ com.google.protobuf
+ protobuf-java
+ 3.16.3
+
+
+ org.apache.calcite
+ calcite-core
+ ${calcite.version}
+
+
+ org.apache.calcite
+ calcite-linq4j
+ ${calcite.version}
+
+
+ org.apache.calcite
+ calcite-file
+ ${calcite.version}
+
+
+
+
+
+ src/main/resources
+
+
+
+
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MLContext.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MLContext.java
new file mode 100644
index 000000000..a0ed6c3c0
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MLContext.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml;
+
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.api.exception.WayangException;
+import org.apache.logging.log4j.Level;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.Job;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.plan.executionplan.ExecutionPlan;
+import org.apache.wayang.core.optimizer.DefaultOptimizationContext;
+import org.apache.wayang.core.optimizer.OptimizationContext;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.encoding.OneHotMappings;
+import org.apache.wayang.ml.encoding.OrtTensorEncoder;
+import org.apache.wayang.ml.encoding.TreeEncoder;
+import org.apache.wayang.ml.encoding.TreeNode;
+import org.apache.wayang.ml.util.EnumerationStrategy;
+import org.apache.wayang.ml.util.Logging;
+import org.apache.wayang.core.util.Tuple;
+import org.apache.logging.log4j.Level;
+
+import java.io.IOException;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.time.Instant;
+import java.time.Duration;
+
+import java.util.ArrayList;
+import java.util.Optional;
+import java.util.Collection;
+
+/**
+ * This is the entry point for users to work with Wayang ML.
+ */
+public class MLContext extends WayangContext {
+
+ private OrtMLModel model;
+
+ private EnumerationStrategy enumerationStrategy = EnumerationStrategy.NONE;
+
+ public MLContext() {
+ super();
+ }
+
+ public MLContext(Configuration configuration) {
+ super(configuration);
+ }
+
+ /**
+ * Execute a plan.
+ *
+ * @param wayangPlan the plan to execute
+ * @param udfJars JARs that declare the code for the UDFs
+ * @see ReflectionUtils#getDeclaringJar(Class)
+ */
+ @Override
+ public void execute(WayangPlan wayangPlan, String... udfJars) {
+ this.setLogLevel(Level.ERROR);
+ Job wayangJob = this.createJob("", wayangPlan, udfJars);
+ OneHotMappings.setOptimizationContext(wayangJob.getOptimizationContext());
+
+ Configuration config = this.getConfiguration();
+ Configuration jobConfig = wayangJob.getConfiguration();
+
+ wayangJob.execute();
+
+ if (config.getBooleanProperty("wayang.ml.experience.enabled")) {
+ String original;
+
+ Optional originalOption = config.getOptionalStringProperty("wayang.ml.experience.original");
+ if (originalOption.isPresent()) {
+ original = originalOption.get();
+ } else {
+ original = TreeEncoder.encode(wayangPlan).toString();
+ }
+
+ String withChoices;
+
+ Optional choicesOption = config.getOptionalStringProperty("wayang.ml.experience.with-platforms");
+ if (choicesOption.isPresent()) {
+ withChoices = choicesOption.get();
+ } else {
+ withChoices = jobConfig.getStringProperty("wayang.ml.experience.with-platforms");
+ }
+
+ long execTime = jobConfig.getLongProperty("wayang.ml.experience.exec-time");
+
+ this.logExperience(original, withChoices, execTime);
+ }
+ }
+
+ public void executeVAE(WayangPlan wayangPlan, String ...udfJars) {
+ this.setLogLevel(Level.ERROR);
+ try {
+ Job job = this.createJob("", wayangPlan, udfJars);
+ Configuration jobConfig = job.getConfiguration();
+ //job.prepareWayangPlan();
+ job.estimateKeyFigures();
+ OneHotMappings.setOptimizationContext(job.getOptimizationContext());
+ OneHotMappings.encodeIds = true;
+
+ // Log Encoding time
+ Instant start = Instant.now();
+ TreeNode wayangNode = TreeEncoder.encode(wayangPlan);
+
+ Instant end = Instant.now();
+ long execTime = Duration.between(start, end).toMillis();
+ Logging.writeToFile(
+ String.format("Encoding: %d", execTime),
+ this.getConfiguration().getStringProperty("wayang.ml.optimizations.file")
+ );
+ OrtMLModel model = OrtMLModel.getInstance(job.getConfiguration());
+ // Log inference time
+ start = Instant.now();
+ Tuple resultTuple = model.runVAE(wayangPlan, wayangNode);
+ end = Instant.now();
+ execTime = Duration.between(start, end).toMillis();
+
+ WayangPlan platformPlan = resultTuple.field0;
+
+ this.getConfiguration().setProperty(
+ "wayang.ml.experience.original",
+ wayangNode.toStringEncoding()
+ );
+
+ this.getConfiguration().setProperty(
+ "wayang.ml.experience.with-platforms",
+ resultTuple.field1.toString()
+ );
+
+ this.execute(platformPlan, udfJars);
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw new WayangException("Executing WayangPlan with VAE model failed");
+ }
+ }
+
+ public ExecutionPlan buildWithVAE(WayangPlan wayangPlan, String ...udfJars) {
+ try {
+ Job job = this.createJob("", wayangPlan, udfJars);
+ job.estimateKeyFigures();
+ OneHotMappings.setOptimizationContext(job.getOptimizationContext());
+ OneHotMappings.encodeIds = true;
+
+ TreeNode wayangNode = TreeEncoder.encode(wayangPlan);
+ OrtMLModel model = OrtMLModel.getInstance(job.getConfiguration());
+ Tuple resultTuple = model.runVAE(wayangPlan, wayangNode);
+ WayangPlan platformPlan = resultTuple.field0;
+
+ return this.buildInitialExecutionPlan("", platformPlan, udfJars);
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw new WayangException("Executing WayangPlan with VAE model failed");
+ }
+ }
+
+ public void setModel(OrtMLModel model) {
+ this.model = model;
+ }
+
+ private void logExperience(String original, String withChoices, long execTime) {
+ if (!this.getConfiguration().getBooleanProperty("wayang.ml.experience.enabled")) {
+ return;
+ }
+
+ String content = String.format("%s:%s:%d", original, withChoices, execTime);
+ Logging.writeToFile(
+ content,
+ this.getConfiguration().getStringProperty("wayang.ml.experience.file")
+ );
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MachineLearning.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MachineLearning.java
new file mode 100644
index 000000000..844d59c2c
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MachineLearning.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.mapping.Mapping;
+import org.apache.wayang.core.optimizer.channels.ChannelConversion;
+import org.apache.wayang.core.platform.Platform;
+import org.apache.wayang.core.plugin.Plugin;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * Provides {@link Plugin}s that enable usage of the xxxx.
+ */
+public class MachineLearning {
+
+ /**
+ * Enables use with the {@link JavaPlatform} and {@link SparkPlatform}.
+ */
+ private static final Plugin PLUGIN = new Plugin() {
+
+ @Override
+ public Collection getRequiredPlatforms() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public Collection getMappings() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public Collection getChannelConversions() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void setProperties(Configuration configuration) {
+ }
+ };
+
+ /**
+ * Retrieve a {@link Plugin} to use xxx on the
+ *
+ * @return the {@link Plugin}
+ */
+ public static Plugin plugin() {
+ return PLUGIN;
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/OrtMLModel.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/OrtMLModel.java
new file mode 100644
index 000000000..7c8fa5c7f
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/OrtMLModel.java
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml;
+
+import ai.onnxruntime.NodeInfo;
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtLoggingLevel;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import ai.onnxruntime.providers.OrtCUDAProviderOptions;
+import ai.onnxruntime.TensorInfo;
+import ai.onnxruntime.OrtSession.Result;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.wayang.core.util.Tuple;
+import org.apache.wayang.ml.encoding.OrtTensorDecoder;
+import org.apache.wayang.ml.encoding.OrtTensorEncoder;
+import org.apache.wayang.ml.encoding.TreeDecoder;
+import org.apache.wayang.ml.encoding.TreeNode;
+import org.apache.wayang.ml.util.Logging;
+import org.apache.wayang.ml.validation.*;
+
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.Vector;
+import java.util.HashSet;
+import java.util.function.BiFunction;
+import java.util.stream.Stream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.time.Instant;
+import java.time.Duration;
+
+public class OrtMLModel {
+
+ private static OrtMLModel INSTANCE;
+
+ private OrtSession session;
+ private OrtEnvironment env;
+ private Configuration configuration;
+
+ private final Map inputMap = new HashMap<>();
+ private final Set requestedOutputs = new HashSet<>();
+
+ public static OrtMLModel getInstance(Configuration configuration) throws OrtException {
+ if (INSTANCE == null) {
+ INSTANCE = new OrtMLModel(configuration);
+ }
+
+ return INSTANCE;
+ }
+
+ private OrtMLModel(Configuration configuration) throws OrtException {
+ this.configuration = configuration;
+ this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
+ }
+
+ private void loadModel(String filePath) throws OrtException {
+ if (this.env == null) {
+ this.env = OrtEnvironment.getEnvironment("org.apache.wayang.ml");
+ this.env.setTelemetry(false);
+ }
+
+ if (this.session == null) {
+ /*
+ OrtCUDAProviderOptions cudaProviderOptions = new OrtCUDAProviderOptions(0);
+ cudaProviderOptions.add("gpu_mem_limit","2147483648");
+ cudaProviderOptions.add("arena_extend_strategy","kSameAsRequested");
+ cudaProviderOptions.add("cudnn_conv_algo_search","DEFAULT");
+ cudaProviderOptions.add("do_copy_in_default_stream","1");
+ cudaProviderOptions.add("cudnn_conv_use_max_workspace","1");
+ cudaProviderOptions.add("cudnn_conv1d_pad_to_nc1d","1");
+ */
+
+ OrtSession.SessionOptions options = new OrtSession.SessionOptions();
+
+ options.setInterOpNumThreads(16);
+ options.setIntraOpNumThreads(16);
+ options.setDeterministicCompute(true);
+ //options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE);
+ //options.addCUDA(cudaProviderOptions);
+ this.session = env.createSession(filePath, options);
+ }
+ }
+
+ // Just here as placeholder
+ public double runModel(long[] encoded) {
+ return 0;
+ }
+
+ /**
+ * Close the session after running, {@link #closeSession()}
+ * @param encodedVector
+ * @return NaN on error, and a predicted cost on any other value.
+ * @throws OrtException
+ */
+ public double runModel(
+ Tuple, ArrayList> input1
+ ) throws OrtException {
+ double costPrediction;
+
+ Map inputInfoList = this.session.getInputInfo();
+ long[] input1Dims = ((TensorInfo) inputInfoList.get("input1").getInfo()).getShape();
+ long[] input2Dims = ((TensorInfo) inputInfoList.get("input2").getInfo()).getShape();
+ //
+ //long[] input1Dims = new long[]{1, input1.field0.get(0).length, input1.field0.get(0)[0].length};
+ //long[] input2Dims = new long[]{1, input1.field1.get(0).length, input1.field1.get(0)[0].length};
+
+ Instant start = Instant.now();
+ float[][][] inputValueStructure = new float[1][(int) input1Dims[1]][(int) input1Dims[2]];
+ long[][][] inputIndexStructure = new long[1][(int) input2Dims[1]][(int) input2Dims[2]];
+
+ //inputValueStructure = input1.field0.toArray(input1Left);
+ for (int i = 0; i < input1.field0.get(0).length; i++) {
+ for (int j = 0; j < input1.field0.get(0)[i].length; j++) {
+ inputValueStructure[0][i][j] = Long.valueOf(
+ input1.field0.get(0)[i][j]
+ ).floatValue();
+ }
+ }
+
+ for (int i = 0; i < input1.field1.get(0).length; i++) {
+ inputIndexStructure[0][i] = input1.field1.get(0)[i];
+ }
+
+ OnnxTensor tensorValues = OnnxTensor.createTensor(env, inputValueStructure);
+ OnnxTensor tensorIndexes = OnnxTensor.createTensor(env, inputIndexStructure);
+
+ this.inputMap.put("input1", tensorValues);
+ this.inputMap.put("input2", tensorIndexes);
+
+ this.requestedOutputs.add("output");
+
+ BiFunction unwrapFunc = (r, s) -> {
+ try {
+ return ((float[]) r.get(s).get().getValue())[0];
+ } catch (OrtException e) {
+ this.inputMap.clear();
+ this.requestedOutputs.clear();
+
+ return Float.NaN;
+ }
+ };
+
+
+ try (Result r = session.run(inputMap, requestedOutputs)) {
+ costPrediction = unwrapFunc.apply(r, "output");
+ Instant end = Instant.now();
+ long execTime = Duration.between(start, end).toMillis();
+
+ Logging.writeToFile(
+ String.format("%d", execTime),
+ this.configuration.getStringProperty("wayang.ml.optimizations.file")
+ );
+ } catch(Exception e) {
+ e.printStackTrace();
+ return 0;
+ } finally {
+ this.inputMap.clear();
+ this.requestedOutputs.clear();
+ }
+
+ return costPrediction;
+ }
+
+ public int runPairwise(
+ Tuple, ArrayList> input1,
+ Tuple, ArrayList> input2
+ ) throws OrtException {
+
+
+ Map inputInfoList = this.session.getInputInfo();
+ long[] input1Dims = ((TensorInfo) inputInfoList.get("input1").getInfo()).getShape();
+ long[] input2Dims = ((TensorInfo) inputInfoList.get("input2").getInfo()).getShape();
+ long[] input3Dims = ((TensorInfo) inputInfoList.get("input3").getInfo()).getShape();
+ long[] input4Dims = ((TensorInfo) inputInfoList.get("input4").getInfo()).getShape();
+
+ float[][][] inputValueStructure = new float[1][(int) input1Dims[1]][(int) input1Dims[2]];
+ long[][][] inputIndexStructure = new long[1][(int) input2Dims[1]][(int) input2Dims[2]];
+ float[][][] input2Left = new float[1][(int) input3Dims[1]][(int) input3Dims[2]];
+ long[][][] input2Right = new long[1][(int) input4Dims[1]][(int) input4Dims[2]];
+
+ for (int i = 0; i < input1.field0.get(0).length; i++) {
+ for (int j = 0; j < input1.field0.get(0)[i].length; j++) {
+ inputValueStructure[0][i][j] = Long.valueOf(
+ input1.field0.get(0)[i][j]
+ ).floatValue();
+ }
+ }
+
+ for (int i = 0; i < input1.field1.get(0).length; i++) {
+ inputIndexStructure[0][i] = input1.field1.get(0)[i];
+ }
+
+ for (int i = 0; i < input2.field0.get(0).length; i++) {
+ for (int j = 0; j < input2.field0.get(0)[i].length; j++) {
+ input2Left[0][i][j] = Long.valueOf(
+ input2.field0.get(0)[i][j]
+ ).floatValue();
+ }
+ }
+
+ for (int i = 0; i < input2.field1.get(0).length; i++) {
+ input2Right[0][i] = input2.field1.get(0)[i];
+ }
+
+ OnnxTensor tensorValues = OnnxTensor.createTensor(env, inputValueStructure);
+ OnnxTensor tensorIndexes = OnnxTensor.createTensor(env, inputIndexStructure);
+ OnnxTensor tensorTwoLeft = OnnxTensor.createTensor(env, input2Left);
+ OnnxTensor tensorTwoRight = OnnxTensor.createTensor(env, input2Right);
+
+ this.inputMap.put("input1", tensorValues);
+ this.inputMap.put("input2", tensorIndexes);
+ this.inputMap.put("input3", tensorTwoLeft);
+ this.inputMap.put("input4", tensorTwoRight);
+
+ this.requestedOutputs.add("output");
+
+ BiFunction unwrapFunc = (r, s) -> {
+ try {
+ float[] result = ((float[]) r.get(s).get().getValue());
+ Float[] convResult = new Float[result.length];
+
+ for (int i = 0; i < result.length; i++) {
+ convResult[i] = result[i];
+ }
+
+ return convResult;
+ } catch (OrtException e) {
+ this.inputMap.clear();
+ this.requestedOutputs.clear();
+
+ e.printStackTrace();
+ return new Float[]{Float.NaN};
+ }
+ };
+
+ try (Result r = session.run(this.inputMap, this.requestedOutputs)) {
+ Float[] result = unwrapFunc.apply(r, "output");
+
+ return Math.round(result[0]);
+ } catch (OrtException e) {
+ e.printStackTrace();
+
+ return 0;
+ } finally {
+ this.inputMap.clear();
+ this.requestedOutputs.clear();
+ }
+ }
+
+ public Tuple runVAE(
+ WayangPlan plan,
+ TreeNode encoded
+ ) throws OrtException {
+ Tuple, ArrayList> input = OrtTensorEncoder.encode(encoded);
+ Map inputInfoList = this.session.getInputInfo();
+ long[] input1Dims = ((TensorInfo) inputInfoList.get("input1").getInfo()).getShape();
+ long[] input2Dims = ((TensorInfo) inputInfoList.get("input2").getInfo()).getShape();
+
+ System.out.println(encoded.toStringEncoding());
+
+ //long[] input1Dims = new long[]{1, input.field0.get(0).length, input.field0.get(0)[0].length};
+ //long[] input2Dims = new long[]{1, input.field1.get(0).length, input.field1.get(0)[0].length};
+
+ System.out.println("Feature dims: " + Arrays.toString(input1Dims));
+ System.out.println("Index dims: " + Arrays.toString(input2Dims));
+
+ System.out.println("Tree size: " + encoded.size());
+
+ int indexDims = encoded.size();
+ long featureDims = input1Dims[1];
+ Instant start = Instant.now();
+
+ float[][] inputValueStructure = new float[(int) featureDims][(int) input1Dims[2]];
+ //long[][][] inputIndexStructure = new long[1][indexDims][1];
+ long[][][] inputIndexStructure = new long[1][(int) input2Dims[1]][(int) input2Dims[2]];
+
+
+ //inputValueStructure = input1.field0.toArray(input1Left);
+ for (int i = 0; i < input.field0.get(0).length; i++) {
+ for (int j = 0; j < input.field0.get(0)[i].length; j++) {
+ // 0th entry as the model could take multiple trees
+ // It only ever takes one here
+ inputValueStructure[i][j] = Long.valueOf(
+ input.field0.get(0)[i][j]
+ ).floatValue();
+ }
+ }
+ /*
+ long[][] inputIndexStructure = input.field1.get(0);
+ */
+
+ long[][] encoderIndexes = input.field1.get(0);
+
+ long maxIndex = Arrays.stream(encoderIndexes)
+ .flatMapToLong(Arrays::stream)
+ .max()
+ .orElseThrow(() -> new IllegalArgumentException("Encoder indexes are empty"));
+
+ assert maxIndex + 1 <= inputValueStructure[0].length : "There isn't a corresponding value for each index";
+
+ for (int i = 0; i < input.field1.get(0).length; i++) {
+ inputIndexStructure[0][i] = input.field1.get(0)[i];
+ }
+
+ OnnxTensor tensorValues = OnnxTensor.createTensor(env, new float[][][]{inputValueStructure});
+ OnnxTensor tensorIndexes = OnnxTensor.createTensor(env, inputIndexStructure);
+
+ OrtTensorDecoder decoder = new OrtTensorDecoder();
+
+ this.inputMap.put("input1", tensorValues);
+ this.inputMap.put("input2", tensorIndexes);
+
+ this.requestedOutputs.add("output");
+
+ BiFunction unwrapFunc = (r, s) -> {
+ try {
+ return ((float[][][]) r.get(s).get().getValue());
+ } catch (OrtException e) {
+ e.printStackTrace();
+ this.inputMap.clear();
+ this.requestedOutputs.clear();
+
+ return null;
+ }
+ };
+
+
+ try (Result r = session.run(inputMap, requestedOutputs)) {
+ float[][][] resultTensor = unwrapFunc.apply(r, "output");
+
+ Instant end = Instant.now();
+ long execTime = Duration.between(start, end).toMillis();
+
+ Logging.writeToFile(
+ String.format("Inference: %d", execTime),
+ this.configuration.getStringProperty("wayang.ml.optimizations.file")
+ );
+
+ start = Instant.now();
+
+ System.out.println("ResultTensor: " + resultTensor.length + ", " + resultTensor[0].length + ", " + resultTensor[0][0].length);
+
+ long[][] platformChoices = PlatformChoiceValidator.validate(
+ resultTensor,
+ inputIndexStructure,
+ encoded,
+ new BitmaskValidationRule(),
+ new OperatorValidationRule(),
+ new PostgresSourceValidationRule()
+ );
+
+ System.out.println("Choices: " + Arrays.deepToString(platformChoices));
+
+ int valueDim = resultTensor[0][0].length;
+ int indexDim = input.field1.get(0).length;
+
+ // Only handle one tree
+ //assert valueDim == indexDim : "Index dim " + indexDim + " != " + valueDim + " valueDim";
+
+ ArrayList mlResult = new ArrayList();
+ mlResult.add(platformChoices);
+
+ Tuple, ArrayList> decoderInput = new Tuple<>(mlResult, input.field1);
+ end = Instant.now();
+ execTime = Duration.between(start, end).toMillis();
+
+ Logging.writeToFile(
+ String.format("Unpacking: %d", execTime),
+ this.configuration.getStringProperty("wayang.ml.optimizations.file")
+ );
+
+ start = Instant.now();
+ TreeNode decoded = decoder.decode(decoderInput);
+
+ //decoded.softmax();
+ end = Instant.now();
+
+ execTime = Duration.between(start, end).toMillis();
+ Logging.writeToFile(
+ String.format("Decoding: %d", execTime),
+ this.configuration.getStringProperty("wayang.ml.optimizations.file")
+ );
+ // Now set the platforms on the wayangPlan
+ start = Instant.now();
+
+ assert decoded.size() == encoded.size() : "Mismatch in Decode and Encode tree sizes";
+
+ TreeNode reconstructed = encoded.withPlatformChoicesFrom(decoded);
+ WayangPlan decodedPlan = TreeDecoder.decode(reconstructed);
+ end = Instant.now();
+
+ execTime = Duration.between(start, end).toMillis();
+ /*Logging.writeToFile(
+ String.format("Reconstruction: %d", execTime),
+ this.configuration.getStringProperty("wayang.ml.optimizations.file")
+ )*/;
+
+
+
+ return new Tuple(decodedPlan, reconstructed);
+ } catch(Exception e) {
+ e.printStackTrace();
+ throw e;
+ //return new Tuple(plan, encoded);
+ } finally {
+ this.inputMap.clear();
+ this.requestedOutputs.clear();
+ this.closeSession();
+ }
+ }
+
+ /**
+ * Closes the OrtModel resource, relinquishing any underlying resources.
+ * @throws OrtException
+ */
+ public void closeSession() throws OrtException {
+ this.session.close();
+ this.env.close();
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/DSBenchmark.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/DSBenchmark.java
new file mode 100644
index 000000000..1e75392c7
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/DSBenchmark.java
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.wayang.api.sql.context.SqlContext;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.basic.operators.JoinOperator;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.flink.Flink;
+import org.apache.wayang.postgres.Postgres;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.spark.Spark;
+import org.apache.logging.log4j.Level;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.wayang.api.PlanBuilder;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.wayang.core.plan.wayangplan.OutputSlot;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.basic.operators.TextFileSource;
+import org.apache.wayang.basic.operators.TableSource;
+import org.apache.wayang.basic.operators.MapOperator;
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.core.util.ExplainUtils;
+import org.apache.wayang.api.sql.calcite.utils.PrintUtils;
+import org.apache.wayang.apps.tpch.queries.Query1Wayang;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.calcite.sql.parser.SqlParseException;
+
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.io.StringWriter;
+import java.lang.reflect.Constructor;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.LinkedList;
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+import java.util.Collection;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.sql.SQLException;
+
+public class DSBenchmark {
+
+ public static SqlContext sqlContext;
+
+ /**
+ * 0: platforms
+ * 1: Data directory
+ * 2: Directory to write timings to
+ * 3: query path
+ * 4: model type
+ * 5: model path
+ * 6: experience path
+ */
+ public static String psqlUser = "postgres";
+ public static String psqlPassword = "postgres";
+
+ public static void main(String[] args) {
+ try {
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ Configuration config = new Configuration();
+ String modelType = "";
+
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "DSB Query");
+ config.setProperty("spark.rpc.message.maxSize", "2047");
+ config.setProperty("spark.executor.memory", "42g");
+ config.setProperty("spark.executor.cores", "4");
+ config.setProperty("spark.executor.instances", "2");
+ config.setProperty("spark.default.parallelism", "8");
+ config.setProperty("spark.driver.maxResultSize", "16g");
+ config.setProperty("spark.shuffle.service.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.minExecutors", "2");
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "1");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("wayang.flink.rest.client.max-content-length", "200MiB");
+ config.setProperty("wayang.flink.collect.path", "file:///work/lsbo-paper/data/flink-data");
+ //config.setProperty("wayang.flink.collect.path", "file:///tmp/flink-data");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+ config.setProperty(
+ "wayang.core.optimizer.pruning.strategies",
+ "org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy"
+ );
+ config.setProperty("wayang.core.optimizer.pruning.topk", "100");
+
+ final String calciteModel = "{\n" +
+ " \"version\": \"1.0\",\n" +
+ " \"defaultSchema\": \"wayang\",\n" +
+ " \"schemas\": [\n" +
+ " {\n" +
+ " \"name\": \"postgres\",\n" +
+ " \"type\": \"custom\",\n" +
+ " \"factory\": \"org.apache.wayang.api.sql.calcite.jdbc.JdbcSchema$Factory\",\n" +
+ " \"operand\": {\n" +
+ " \"jdbcDriver\": \"org.postgresql.Driver\",\n" +
+ " \"jdbcUrl\": \"jdbc:postgresql://dsb:5432/dsb\",\n" +
+ " \"jdbcUser\": \"" + psqlUser + "\",\n" +
+ " \"jdbcPassword\": \"" + psqlPassword + "\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ config.setProperty("org.apache.calcite.sql.parser.parserTracing", "true");
+ config.setProperty("wayang.calcite.model", calciteModel);
+ config.setProperty("wayang.postgres.jdbc.url", "jdbc:postgresql://dsb:5432/dsb");
+ config.setProperty("wayang.postgres.jdbc.user", psqlUser);
+ config.setProperty("wayang.postgres.jdbc.password", psqlPassword);
+
+ if (args.length > 4) {
+ modelType = args[4];
+ }
+
+ if (args.length > 6) {
+ DSBenchmark.setMLModel(config, modelType, args[5], args[6]);
+ }
+
+ // Take the query name
+ String fileName = args[3].substring(args[3].lastIndexOf("/") + 1);
+ String queryName = fileName.substring(0, fileName.lastIndexOf("."));
+
+ String executionTimeFile = args[2] + "query-executions-" + queryName;
+ String optimizationTimeFile = args[2] + "query-optimizations-" + queryName;
+
+ if (!"".equals(modelType)) {
+ executionTimeFile += "-" + modelType;
+ optimizationTimeFile += "-" + modelType;
+ }
+
+ config.setProperty(
+ "wayang.ml.executions.file",
+ executionTimeFile + ".txt"
+ );
+
+ config.setProperty(
+ "wayang.ml.optimizations.file",
+ optimizationTimeFile + ".txt"
+ );
+
+ final MLContext wayangContext = new MLContext(config);
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(DSBenchmark.class),
+ ReflectionUtils.getLibs(DSBenchmark.class)
+ );
+
+ jars = ArrayUtils.addAll(
+ jars,
+ ReflectionUtils.getAllJars(org.apache.calcite.runtime.SqlFunctions.class)
+ );
+
+ try {
+ WayangPlan plan = DSBenchmark.getWayangPlan(args[3], config, plugins.toArray(Plugin[]::new), jars);
+
+ wayangContext.setLogLevel(Level.DEBUG);
+
+ if (!"vae".equals(modelType) && !"bvae".equals(modelType)) {
+ System.out.println("Executing query " + args[3]);
+ wayangContext.execute(plan, jars);
+ System.out.println("Finished execution");
+ } else {
+ System.out.println("Using vae cost model");
+ System.out.println("Executing query " + args[3]);
+ wayangContext.executeVAE(plan, jars);
+ System.out.println("Finished execution");
+ }
+ } catch (SqlParseException sqlE) {
+ sqlE.printStackTrace();
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static void setMLModel(Configuration config, String modelType, String path, String experiencePath) {
+ config.setProperty(
+ "wayang.ml.model.file",
+ path
+ );
+
+ switch(modelType) {
+ case "cost":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-cost.txt");
+
+ config.setCostModel(new PointwiseCost());
+ System.out.println("Using cost ML Model");
+
+ break;
+ case "pairwise":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-pairwise.txt");
+ config.setCostModel(new PairwiseCost());
+
+ System.out.println("Using pairwise ML Model");
+ break;
+ case "bvae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-bvae.txt");
+
+ System.out.println("Using bvae ML Model");
+ break;
+ case "vae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-vae.txt");
+
+ System.out.println("Using vae ML Model");
+ break;
+ default:
+ System.out.println("Using default cost Model");
+ break;
+ }
+
+ }
+
+ public static WayangPlan getWayangPlan(
+ final String path,
+ final Configuration configuration,
+ final Plugin[] plugins,
+ final String... udfJars
+ ) throws SQLException, IOException, org.apache.calcite.sql.parser.SqlParseException {
+ sqlContext = new SqlContext(configuration, plugins);
+ final Path pathToQuery = Paths.get(path);
+
+ // need to chop off the last ';' otherwise sqlContext cant parse it
+ final String query = StringUtils.chop(Files.readString(pathToQuery).stripTrailing());
+
+ WayangPlan plan = sqlContext.buildWayangPlan(query, udfJars);
+
+ return plan;
+ }
+
+
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/GeneratableBenchmarks.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/GeneratableBenchmarks.java
new file mode 100644
index 000000000..1d7b3849d
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/GeneratableBenchmarks.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.flink.Flink;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.wayang.api.PlanBuilder;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.wayang.ml.training.TPCH;
+import org.apache.wayang.apps.tpch.queries.Query1Wayang;
+import org.apache.wayang.apps.tpch.queries.Query3;
+import org.apache.wayang.apps.tpch.queries.Query5;
+import org.apache.wayang.apps.tpch.queries.Query6;
+import org.apache.wayang.apps.tpch.queries.Query10;
+import org.apache.wayang.apps.tpch.queries.Query12;
+import org.apache.wayang.apps.tpch.queries.Query14;
+import org.apache.wayang.apps.tpch.queries.Query19;
+import org.apache.wayang.ml.training.GeneratableJob;
+import org.apache.wayang.ml.util.Jobs;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.lang.reflect.Constructor;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.LinkedList;
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+import java.util.Collection;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+
+public class GeneratableBenchmarks {
+
+ /**
+ * 0: platforms
+ * 1: TPCH data set directory path
+ * 2: Directory to write timings to
+ * 3: query number
+ * 4: model type
+ * 5: model path
+ * 6: experience path
+ */
+ public static void main(String[] args) {
+ try {
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ Class extends GeneratableJob> job = Jobs.getJob(Integer.parseInt(args[3]));
+ System.out.println("Job: " + job.getName());
+ Configuration config = new Configuration();
+ String modelType = "";
+
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "JOB Query");
+ config.setProperty("spark.rpc.message.maxSize", "2047");
+
+ // Executor memory (on 48GB nodes)
+ config.setProperty("spark.executor.memory", "36g");
+ config.setProperty("spark.executor.memoryOverhead", "4g");
+ config.setProperty("spark.executor.cores", "8");
+
+ // Driver memory (on 48GB master node)
+ config.setProperty("spark.driver.memory", "24g");
+ config.setProperty("spark.driver.memoryOverhead", "2g");
+ config.setProperty("spark.driver.maxResultSize", "8g");
+
+ // Dynamic allocation (drop spark.executor.instances)
+ config.setProperty("spark.dynamicAllocation.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.minExecutors", "2");
+ config.setProperty("spark.dynamicAllocation.maxExecutors", "2");
+ config.setProperty("spark.dynamicAllocation.initialExecutors", "2");
+ config.setProperty("spark.shuffle.service.enabled", "true");
+
+ // Parallelism
+ config.setProperty("spark.default.parallelism", "16"); // 2 executors * 8 cores
+
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "1");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("wayang.flink.rest.client.max-content-length", "200MiB");
+ config.setProperty("wayang.flink.collect.path", "file:///work/lsbo-paper/data/flink-data");
+ //config.setProperty("wayang.flink.collect.path", "file:///tmp/flink-data");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+ config.setProperty(
+ "wayang.core.optimizer.pruning.strategies",
+ "org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy"
+ );
+ config.setProperty("wayang.core.optimizer.pruning.topk", "100");
+
+ if (args.length > 4) {
+ modelType = args[4];
+ }
+
+ if (args.length > 6) {
+ GeneratableBenchmarks.setMLModel(config, modelType, args[5], args[6]);
+ }
+
+ String executionTimeFile = args[2] + "query" + args[3] + "-executions";
+ String optimizationTimeFile = args[2] + "query" + args[3] + "-optimizations";
+
+ if (!"".equals(modelType)) {
+ executionTimeFile += "-" + modelType;
+ optimizationTimeFile += "-" + modelType;
+ }
+
+ config.setProperty(
+ "wayang.ml.executions.file",
+ executionTimeFile + ".txt"
+ );
+
+ config.setProperty(
+ "wayang.ml.optimizations.file",
+ optimizationTimeFile + ".txt"
+ );
+
+ final MLContext wayangContext = new MLContext(config);
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+ Constructor> cnstr = job.getDeclaredConstructors()[0];
+ GeneratableJob createdJob = (GeneratableJob) cnstr.newInstance();
+ String[] jobArgs = {args[0], args[1]};
+ DataQuanta> quanta = createdJob.buildPlan(jobArgs);
+ PlanBuilder builder = quanta.getPlanBuilder();
+ WayangPlan plan = builder.build();
+
+ //Set sink to be on Java
+ //((LinkedList )plan.getSinks()).get(0).addTargetPlatform(Java.platform());
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(GeneratableBenchmarks.class),
+ ReflectionUtils.getAllJars(org.apache.calcite.rel.externalize.RelJson.class)
+ );
+
+ jars = ArrayUtils.addAll(
+ jars,
+ ReflectionUtils.getAllJars(DataQuanta.class)
+ );
+
+ /*
+ FileWriter fw = new FileWriter(
+ "/var/www/html/data/benchmarks/operators.txt",
+ true
+ );
+ BufferedWriter writer = new BufferedWriter(fw);
+
+ final Collection operators = PlanTraversal.upstream().traverse(plan.getSinks()).getTraversedNodes();
+
+ System.out.println("Operators: " + operators.size());
+
+ writer.write(args[3] + ": " + operators.size());
+ writer.newLine();
+ writer.flush();
+ writer.close();
+ */
+ System.out.println(modelType);
+ if (!"vae".equals(modelType) && !"bvae".equals(modelType)) {
+ System.out.println("Executing query " + args[3]);
+ wayangContext.execute(plan, jars);
+ System.out.println("Finished execution");
+ } else {
+ System.out.println("Using vae cost model");
+ System.out.println("Executing query " + args[3]);
+ wayangContext.executeVAE(plan, jars);
+ System.out.println("Finished execution");
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static void setMLModel(Configuration config, String modelType, String path, String experiencePath) {
+ config.setProperty(
+ "wayang.ml.model.file",
+ path
+ );
+
+ switch(modelType) {
+ case "cost":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-cost.txt");
+
+ config.setCostModel(new PointwiseCost());
+ System.out.println("Using cost ML Model");
+
+ break;
+ case "pairwise":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-pairwise.txt");
+ config.setCostModel(new PairwiseCost());
+
+ System.out.println("Using pairwise ML Model");
+ break;
+ case "bvae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-bvae.txt");
+
+ System.out.println("Using bvae ML Model");
+ break;
+ case "vae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-vae.txt");
+
+ System.out.println("Using vae ML Model");
+ break;
+ default:
+ System.out.println("Using default cost Model");
+ break;
+ }
+
+ }
+
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/IMDBJOBenchmark.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/IMDBJOBenchmark.java
new file mode 100644
index 000000000..8e29524b1
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/IMDBJOBenchmark.java
@@ -0,0 +1,337 @@
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.wayang.api.sql.context.SqlContext;
+import org.apache.wayang.basic.data.JVMRecord;
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.OutputSlot;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.flink.Flink;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.postgres.Postgres;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.basic.operators.TextFileSource;
+import org.apache.wayang.basic.operators.TableSource;
+import org.apache.wayang.basic.operators.MapOperator;
+import org.apache.wayang.basic.data.JVMRecord;
+import org.apache.wayang.basic.types.RecordType;
+import org.apache.wayang.core.function.TransformationDescriptor;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.types.DataUnitType;
+import org.apache.wayang.apps.imdb.data.*;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.sql.SQLException;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.Arrays;
+import java.util.stream.Stream;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.ArrayUtils;
+
+public class IMDBJOBenchmark {
+ public static SqlContext sqlContext;
+
+ public static final int MAX_SOURCES_REPLACED = 3;
+
+ public static WayangPlan getWayangPlan(
+ final String path,
+ final Configuration configuration,
+ final Plugin[] plugins,
+ final String... udfJars
+ ) throws SQLException, IOException, org.apache.calcite.sql.parser.SqlParseException {
+
+ sqlContext = new SqlContext(configuration, Spark.basicPlugin(), Postgres.plugin(),
+ Java.channelConversionPlugin(), Java.testSuitePlugin(), Flink.basicPlugin());
+
+ final Path pathToQuery = Paths.get(path);
+ final String query = StringUtils.chop(Files.readString(pathToQuery).stripTrailing()); // need to chop off
+ /*
+ sqlContext = new SqlContext(configuration, plugins);
+ final Path pathToQuery = Paths.get(path);
+
+ // need to chop off the last ';' otherwise sqlContext cant parse it
+ final String query = StringUtils.chop(Files.readString(pathToQuery).stripTrailing());
+ */
+
+ WayangPlan plan = sqlContext.buildWayangPlan(query, udfJars);
+
+ //((LinkedList )plan.getSinks()).get(0).addTargetPlatform(Java.platform());
+
+ return plan;
+ }
+
+ /**
+ * Benchmarking tool for the imdb/jo benchmark, Calcite dictates that every jo
+ * query follows
+ * the schema, "schema_name.table_name". The tool searches for the queries in
+ * resources/calcite-ready-job-queries
+ *
+ * @param args args[0]: path to calcite-job-ready-queries/*.sql
+ */
+ public static void main(final String[] args) throws Exception {
+ try {
+ final Configuration configuration = new Configuration();
+
+ final String calciteModel = "{\n" +
+ " \"version\": \"1.0\",\n" +
+ " \"defaultSchema\": \"wayang\",\n" +
+ " \"schemas\": [\n" +
+ " {\n" +
+ " \"name\": \"postgres\",\n" +
+ " \"type\": \"custom\",\n" +
+ " \"factory\": \"org.apache.wayang.api.sql.calcite.jdbc.JdbcSchema$Factory\",\n" +
+ " \"operand\": {\n" +
+ " \"jdbcDriver\": \"org.postgresql.Driver\",\n" +
+ " \"jdbcUrl\": \"jdbc:postgresql://job:5432/job\",\n" +
+ " \"jdbcUser\": \"postgres\",\n" +
+ " \"jdbcPassword\": \"postgres\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ configuration.setProperty("org.apache.calcite.sql.parser.parserTracing", "true");
+ configuration.setProperty("wayang.calcite.model", calciteModel);
+ configuration.setProperty("wayang.postgres.jdbc.url", "jdbc:postgresql://job:5432/job");
+ configuration.setProperty("wayang.postgres.jdbc.user", "postgres");
+ configuration.setProperty("wayang.postgres.jdbc.password", "postgres");
+
+ configuration.setProperty(
+ "wayang.ml.executions.file",
+ "mle" + ".txt");
+
+ configuration.setProperty(
+ "wayang.ml.optimizations.file",
+ "mlo" + ".txt");
+
+ configuration.setProperty("wayang.ml.experience.enabled", "false");
+
+ final SqlContext sqlContext = new SqlContext(configuration, Spark.basicPlugin(), Postgres.plugin(),
+ Java.channelConversionPlugin(), Java.testSuitePlugin(), Flink.basicPlugin());
+
+ final Path pathToQuery = Paths.get(args[0]);
+ final String query = StringUtils.chop(Files.readString(pathToQuery).stripTrailing()); // need to chop off
+ // the last
+ // ';' otherwise sqlContext
+ // cant parse it
+
+ final Collection result = sqlContext.executeSql(
+ query);
+
+ //System.out.println(result.stream().limit(50).collect(Collectors.toList()));
+ //System.out.println("\nResults: " + " amount of records: " + result.size());
+ } catch (Exception e) {
+ e.printStackTrace();
+ System.exit(5);
+ }
+ }
+
+ // Only source in postgres, compute elsewhere
+ public static void setSources(WayangPlan plan, String dataPath) {
+
+ final List sources = plan.collectReachableTopLevelSources()
+ .stream()
+ .map(op -> (TableSource) op)
+ .sorted(Comparator.comparing(op -> op.getTableName()))
+ .collect(Collectors.toList());
+
+ Set replacedSources = new HashSet<>();
+ boolean isSet = false;
+ int nrOfSourcesReplaced = 0;
+
+ for (Operator op : sources) {
+ if (nrOfSourcesReplaced >= MAX_SOURCES_REPLACED) {
+ return;
+ }
+
+ if (op instanceof TableSource) {
+ String tableName = ((TableSource) op).getTableName();
+ String filePath = dataPath + tableName + ".csv";
+ if (!replacedSources.contains(tableName)) {
+ //if (!isSet) {
+ TextFileSource replacement = new TextFileSource(filePath, "UTF-8");
+
+ MapOperator parser;
+
+ switch (tableName) {
+ case "movie_companies": parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(MovieCompanies.toArray(MovieCompanies.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+ //System.out.println("Setting to file: " + tableName);
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+
+ break;
+ case "aka_name":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(AkaName.toArray(AkaName.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+ //System.out.println("Setting to file: " + tableName);
+ replacedSources.add(tableName);
+ replacement.connectTo(0, parser, 0);
+ isSet = true;
+ nrOfSourcesReplaced++;
+
+ break;
+ case "comp_cast_type":
+ parser = new MapOperator<>(
+ (line) -> { return new JVMRecord(CompCastType.toArray(CompCastType.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ case "company_name":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(CompanyName.toArray(CompanyName.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ case "info_type":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(InfoType.toArray(InfoType.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ case "movie_info":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(MovieInfo.toArray(MovieInfo.parseCsv(line)));
+ },
+
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ case "person_info":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(PersonInfo.toArray(PersonInfo.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ case "movie_keyword":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(MovieKeyword.toArray(MovieKeyword.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+
+ break;
+ case "cast_info":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(CastInfo.toArray(CastInfo.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ case "movie_link":
+ parser = new MapOperator<>(
+ (line) -> {
+ return new JVMRecord(MovieLink.toArray(MovieLink.parseCsv(line)));
+ },
+ String.class,
+ JVMRecord.class
+ );
+ OutputSlot.stealConnections(op, parser);
+
+ replacement.connectTo(0, parser, 0);
+ replacedSources.add(tableName);
+ isSet = true;
+ nrOfSourcesReplaced++;
+ //System.out.println("Setting to file: " + tableName);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/Inference.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/Inference.java
new file mode 100644
index 000000000..855138d3b
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/Inference.java
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.executionplan.ExecutionPlan;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.basic.operators.JoinOperator;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.flink.Flink;
+import org.apache.wayang.postgres.Postgres;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.spark.Spark;
+import org.apache.logging.log4j.Level;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.wayang.api.PlanBuilder;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.wayang.ml.encoding.TreeEncoder;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.OutputSlot;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.basic.operators.TextFileSource;
+import org.apache.wayang.basic.operators.TableSource;
+import org.apache.wayang.basic.operators.MapOperator;
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.core.util.ExplainUtils;
+import org.apache.wayang.api.sql.calcite.utils.PrintUtils;
+import org.apache.wayang.api.sql.context.SqlContext;
+import org.apache.wayang.apps.tpch.queries.Query1Wayang;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.calcite.sql.parser.SqlParseException;
+import org.apache.wayang.core.api.Job;
+import org.apache.wayang.ml.encoding.OneHotMappings;
+
+import java.io.StringWriter;
+import java.lang.reflect.Constructor;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.LinkedList;
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+import java.util.Collection;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+
+public class Inference {
+
+ /**
+ * 0: platforms
+ * 1: Data directory
+ * 2: Directory to write timings to
+ * 3: query path
+ * 4: model type
+ * 5: model path
+ * 6: experience path
+ */
+ public static String psqlUser = "postgres";
+ public static String psqlPassword = "postgres";
+
+ public static void main(String[] args) {
+ try {
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ Configuration config = new Configuration();
+ String modelType = "";
+
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "JOB Query");
+ config.setProperty("spark.rpc.message.maxSize", "2047");
+ config.setProperty("spark.executor.memory", "32g");
+ config.setProperty("spark.executor.cores", "4");
+ config.setProperty("spark.executor.instances", "1");
+ config.setProperty("spark.default.parallelism", "8");
+ config.setProperty("spark.driver.maxResultSize", "16g");
+ config.setProperty("spark.dynamicAllocation.enabled", "true");
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "1");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("wayang.flink.rest.client.max-content-length", "200MiB");
+ config.setProperty("wayang.flink.collect.path", "file:///work/lsbo-paper/data/flink-data");
+ //config.setProperty("wayang.flink.collect.path", "file:///tmp/flink-data");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+ config.setProperty(
+ "wayang.core.optimizer.pruning.strategies",
+ "org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy"
+ );
+ config.setProperty("wayang.core.optimizer.pruning.topk", "1000");
+
+ final String calciteModel = "{\n" +
+ " \"version\": \"1.0\",\n" +
+ " \"defaultSchema\": \"wayang\",\n" +
+ " \"schemas\": [\n" +
+ " {\n" +
+ " \"name\": \"postgres\",\n" +
+ " \"type\": \"custom\",\n" +
+ " \"factory\": \"org.apache.wayang.api.sql.calcite.jdbc.JdbcSchema$Factory\",\n" +
+ " \"operand\": {\n" +
+ " \"jdbcDriver\": \"org.postgresql.Driver\",\n" +
+ " \"jdbcUrl\": \"jdbc:postgresql://job:5432/job\",\n" +
+ " \"jdbcUser\": \"" + psqlUser + "\",\n" +
+ " \"jdbcPassword\": \"" + psqlPassword + "\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ config.setProperty("org.apache.calcite.sql.parser.parserTracing", "true");
+ config.setProperty("wayang.calcite.model", calciteModel);
+ config.setProperty("wayang.postgres.jdbc.url", "jdbc:postgresql://job:5432/job");
+ config.setProperty("wayang.postgres.jdbc.user", psqlUser);
+ config.setProperty("wayang.postgres.jdbc.password", psqlPassword);
+
+ if (args.length > 4) {
+ modelType = args[4];
+ }
+
+ if (args.length > 6) {
+ Inference.setMLModel(config, modelType, args[5], args[6]);
+ }
+
+ // Take the query name
+ String fileName = args[3].substring(args[3].lastIndexOf("/") + 1);
+ String queryName = fileName.substring(0, fileName.lastIndexOf("."));
+
+ String executionTimeFile = args[2] + "query-executions-" + queryName;
+ String optimizationTimeFile = args[2] + "query-optimizations-" + queryName;
+
+ if (!"".equals(modelType)) {
+ executionTimeFile += "-" + modelType;
+ optimizationTimeFile += "-" + modelType;
+ }
+
+ config.setProperty(
+ "wayang.ml.executions.file",
+ executionTimeFile + ".txt"
+ );
+
+ config.setProperty(
+ "wayang.ml.optimizations.file",
+ optimizationTimeFile + ".txt"
+ );
+
+ final MLContext wayangContext = new MLContext(config);
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(JOBenchmark.class),
+ ReflectionUtils.getLibs(JOBenchmark.class)
+ );
+
+ jars = ArrayUtils.addAll(
+ jars,
+ ReflectionUtils.getAllJars(org.apache.calcite.runtime.SqlFunctions.class)
+ );
+
+ try {
+ WayangPlan plan = IMDBJOBenchmark.getWayangPlan(args[3], config, plugins.toArray(Plugin[]::new), jars);
+
+ IMDBJOBenchmark.setSources(plan, args[1]);
+
+ ExecutionPlan executionPlan;
+
+ if (!"vae".equals(modelType) && !"bvae".equals(modelType)) {
+ executionPlan = wayangContext.buildInitialExecutionPlan("", plan, jars);
+
+ Job job = wayangContext.createJob("", plan, jars);
+ Configuration jobConfig = job.getConfiguration();
+ job.estimateKeyFigures();
+ OneHotMappings.setOptimizationContext(job.getOptimizationContext());
+ OneHotMappings.encodeIds = false;
+ } else {
+ OneHotMappings.encodeIds = true;
+ executionPlan = wayangContext.buildWithVAE(plan, jars);
+ OneHotMappings.encodeIds = false;
+ }
+
+ ExplainUtils.parsePlan(executionPlan, true);
+ //System.out.println(TreeEncoder.encode(executionPlan, true).toStringEncoding());
+ System.out.println("DONE");
+ } catch (SqlParseException sqlE) {
+ sqlE.printStackTrace();
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static void setMLModel(Configuration config, String modelType, String path, String experiencePath) {
+ config.setProperty(
+ "wayang.ml.model.file",
+ path
+ );
+
+ switch(modelType) {
+ case "cost":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-cost.txt");
+
+ config.setCostModel(new PointwiseCost());
+
+ break;
+ case "pairwise":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-pairwise.txt");
+ config.setCostModel(new PairwiseCost());
+
+ break;
+ case "bvae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-bvae.txt");
+
+ break;
+ case "vae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-vae.txt");
+
+ break;
+ default:
+ break;
+ }
+
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/JOBLightQuery1.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/JOBLightQuery1.java
new file mode 100644
index 000000000..bf03a8846
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/JOBLightQuery1.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.function.FlatMapDescriptor;
+import org.apache.wayang.core.function.ReduceDescriptor;
+import org.apache.wayang.core.function.TransformationDescriptor;
+import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.types.DataUnitType;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.java.platform.JavaPlatform;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.spark.platform.SparkPlatform;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.ml.costs.MLCost;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.logging.log4j.Level;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.apps.imdb.data.*;
+import org.apache.commons.lang3.ArrayUtils;
+
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Example Apache Wayang (incubating) App that does a word count -- the Hello World of Map/Reduce-like systems.
+ */
+public class JOBLightQuery1 {
+
+ /**
+ * Creates the {@link WayangPlan} for the word count app.
+ *
+ * @param inputFileUrl the file whose words should be counted
+ */
+ public static WayangPlan createWayangPlan(String inputFilesUrl, Collection, MovieInfo>> collector) throws URISyntaxException, IOException {
+ // TextFileSources
+ TextFileSource titleSource = new TextFileSource(inputFilesUrl + "/title.csv");
+ titleSource.setName("Load title file");
+
+ TextFileSource movieCompaniesSource = new TextFileSource(inputFilesUrl + "/movie_companies.csv");
+ movieCompaniesSource.setName("Load movie_companies file");
+
+ TextFileSource movieInfoIdxSource = new TextFileSource(inputFilesUrl + "/movie_info_idx.csv");
+ movieInfoIdxSource.setName("Load movie_info_idx file");
+
+ // Parsing
+ MapOperator titleParser = new MapOperator<>(
+ (line) -> Title.parseCsv(line), String.class, Title.class
+ );
+
+ MapOperator mcParser = new MapOperator<>(
+ (line) -> MovieCompanies.parseCsv(line), String.class, MovieCompanies.class
+ );
+
+ MapOperator miIdxParser = new MapOperator<>(
+ (line) -> MovieInfo.parseCsv(line), String.class, MovieInfo.class
+ );
+
+ // Filters pushed down
+ FilterOperator miIdxFilter = new FilterOperator<>(
+ (tuple) -> tuple.infoTypeId() == 112,
+ MovieInfo.class
+ );
+
+ FilterOperator mcFilter = new FilterOperator<>(
+ (tuple) -> tuple.companyTypeId() == 2,
+ MovieCompanies.class
+ );
+
+
+ // Joins
+ JoinOperator tMcJoin = new JoinOperator<>(
+ (title) -> title.id(),
+ (mc) -> mc.movieId(),
+ Title.class,
+ MovieCompanies.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieInfo, Integer> tMcMiIdxJoin = new JoinOperator<>(
+ (tuple) -> tuple.field0.id(),
+ (miIdx) -> miIdx.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieInfo.class,
+ Integer.class
+ );
+
+ // Sink
+ LocalCallbackSink, MovieInfo>> sink = LocalCallbackSink.createCollectingSink(
+ collector,
+ ReflectionUtils.specify(Tuple2.class)
+ );
+
+
+ // Operator connections
+ titleSource.connectTo(0, titleParser, 0);
+ movieCompaniesSource.connectTo(0, mcParser, 0);
+ movieInfoIdxSource.connectTo(0, miIdxParser, 0);
+
+ mcParser.connectTo(0, mcFilter, 0);
+ miIdxParser.connectTo(0, miIdxFilter, 0);
+
+ titleParser.connectTo(0, tMcJoin, 0);
+ mcFilter.connectTo(0, tMcJoin, 1);
+
+ tMcJoin.connectTo(0, tMcMiIdxJoin, 0);
+ miIdxFilter.connectTo(0, tMcMiIdxJoin, 1);
+
+ tMcMiIdxJoin.connectTo(0, sink, 0);
+
+ return new WayangPlan(sink);
+ }
+
+ public static void main(String[] args) throws IOException, URISyntaxException {
+ try {
+ if (args.length == 0) {
+ System.err.print("Usage: [,]* ");
+ System.exit(1);
+ }
+
+ List, MovieInfo>> collector = new LinkedList<>();
+ WayangPlan wayangPlan = createWayangPlan(args[1], collector);
+
+ Configuration config = new Configuration();
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "JOB Query");
+ config.setProperty("spark.rpc.message.maxSize", "2047");
+ config.setProperty("spark.executor.memory", "42g");
+ config.setProperty("spark.executor.cores", "4");
+ config.setProperty("spark.executor.instances", "2");
+ config.setProperty("spark.default.parallelism", "8");
+ config.setProperty("spark.driver.maxResultSize", "16g");
+ config.setProperty("spark.shuffle.service.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.minExecutors", "2");
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "1");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("wayang.flink.rest.client.max-content-length", "200MiB");
+ config.setProperty("wayang.flink.collect.path", "file:///work/lsbo-paper/data/flink-data");
+ //config.setProperty("wayang.flink.collect.path", "file:///tmp/flink-data");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+ config.setProperty(
+ "wayang.core.optimizer.pruning.strategies",
+ "org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy"
+ );
+ config.setProperty("wayang.core.optimizer.pruning.topk", "100");
+
+ final WayangContext wayangContext = new WayangContext(config);
+
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(JOBLightQuery1.class),
+ ReflectionUtils.getLibs(JOBLightQuery1.class)
+ );
+
+ wayangContext.execute(wayangPlan, jars);
+
+ System.out.printf("Result size: %d\n", collector.size());
+ } catch (Exception e) {
+ System.err.println("App failed.");
+ e.printStackTrace();
+ System.exit(4);
+ }
+ }
+
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/JOBenchmark.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/JOBenchmark.java
new file mode 100644
index 000000000..897ef77a9
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/JOBenchmark.java
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.basic.operators.JoinOperator;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.flink.Flink;
+import org.apache.wayang.postgres.Postgres;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.spark.Spark;
+import org.apache.logging.log4j.Level;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.wayang.api.PlanBuilder;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.OutputSlot;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.basic.operators.TextFileSource;
+import org.apache.wayang.basic.operators.TableSource;
+import org.apache.wayang.basic.operators.MapOperator;
+import org.apache.wayang.basic.data.Record;
+import org.apache.wayang.core.util.ExplainUtils;
+import org.apache.wayang.api.sql.calcite.utils.PrintUtils;
+import org.apache.wayang.api.sql.context.SqlContext;
+import org.apache.wayang.apps.tpch.queries.Query1Wayang;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.calcite.sql.parser.SqlParseException;
+
+import java.io.StringWriter;
+import java.lang.reflect.Constructor;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+import java.util.Collection;
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+
+public class JOBenchmark {
+
+ /**
+ * 0: platforms
+ * 1: Data directory
+ * 2: Directory to write timings to
+ * 3: query path
+ * 4: model type
+ * 5: model path
+ * 6: experience path
+ */
+ public static String psqlUser = "ucloud";
+ public static String psqlPassword = "ucloud";
+
+ public static void main(String[] args) {
+ try {
+ List pluginsSeq = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ ArrayList plugins = new ArrayList<>(pluginsSeq);
+ plugins.add(Java.testSuitePlugin());
+ plugins.add(Java.channelConversionPlugin());
+
+ Configuration config = new Configuration();
+ String modelType = "";
+
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "JOB Query");
+ config.setProperty("spark.rpc.message.maxSize", "2047");
+ config.setProperty("spark.executor.memory", "42g");
+ config.setProperty("spark.executor.cores", "4");
+ config.setProperty("spark.executor.instances", "2");
+ config.setProperty("spark.default.parallelism", "8");
+ config.setProperty("spark.driver.maxResultSize", "16g");
+ config.setProperty("spark.shuffle.service.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.minExecutors", "2");
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "1");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("wayang.flink.rest.client.max-content-length", "200MiB");
+ config.setProperty("wayang.flink.collect.path", "file:///work/lsbo-paper/data/flink-data");
+ //config.setProperty("wayang.flink.collect.path", "file:///tmp/flink-data");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+ config.setProperty(
+ "wayang.core.optimizer.pruning.strategies",
+ "org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy"
+ );
+ config.setProperty("wayang.core.optimizer.pruning.topk", "100");
+
+ final String calciteModel = "{\n" +
+ " \"version\": \"1.0\",\n" +
+ " \"defaultSchema\": \"wayang\",\n" +
+ " \"schemas\": [\n" +
+ " {\n" +
+ " \"name\": \"postgres\",\n" +
+ " \"type\": \"custom\",\n" +
+ " \"factory\": \"org.apache.wayang.api.sql.calcite.jdbc.JdbcSchema$Factory\",\n" +
+ " \"operand\": {\n" +
+ " \"jdbcDriver\": \"org.postgresql.Driver\",\n" +
+ " \"jdbcUrl\": \"jdbc:postgresql://job:5432/job\",\n" +
+ " \"jdbcUser\": \"" + psqlUser + "\",\n" +
+ " \"jdbcPassword\": \"" + psqlPassword + "\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ config.setProperty("org.apache.calcite.sql.parser.parserTracing", "true");
+ config.setProperty("wayang.calcite.model", calciteModel);
+ config.setProperty("wayang.postgres.jdbc.url", "jdbc:postgresql://job:5432/job");
+ config.setProperty("wayang.postgres.jdbc.user", psqlUser);
+ config.setProperty("wayang.postgres.jdbc.password", psqlPassword);
+
+ if (args.length > 4) {
+ modelType = args[4];
+ }
+
+ if (args.length > 6) {
+ JOBenchmark.setMLModel(config, modelType, args[5], args[6]);
+ }
+
+ // Take the query name
+ String fileName = args[3].substring(args[3].lastIndexOf("/") + 1);
+ String queryName = fileName.substring(0, fileName.lastIndexOf("."));
+
+ String executionTimeFile = args[2] + "query-executions-" + queryName;
+ String optimizationTimeFile = args[2] + "query-optimizations-" + queryName;
+
+ if (!"".equals(modelType)) {
+ executionTimeFile += "-" + modelType;
+ optimizationTimeFile += "-" + modelType;
+ }
+
+ config.setProperty(
+ "wayang.ml.executions.file",
+ executionTimeFile + ".txt"
+ );
+
+ config.setProperty(
+ "wayang.ml.optimizations.file",
+ optimizationTimeFile + ".txt"
+ );
+
+ final MLContext wayangContext = new MLContext(config);
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(JOBenchmark.class),
+ ReflectionUtils.getLibs(JOBenchmark.class)
+ );
+
+ jars = ArrayUtils.addAll(
+ jars,
+ ReflectionUtils.getAllJars(org.apache.calcite.runtime.SqlFunctions.class)
+ );
+
+ try {
+ WayangPlan plan = IMDBJOBenchmark.getWayangPlan(args[3], config, plugins.toArray(Plugin[]::new), jars);
+
+ IMDBJOBenchmark.setSources(plan, args[1]);
+
+ //ExplainUtils.parsePlan(plan, true);
+
+ //Set sink to be on Java
+ //((LinkedList )plan.getSinks()).get(0).addTargetPlatform(Java.platform());
+ //
+ /*
+ FileWriter fw = new FileWriter(
+ "/var/www/html/data/benchmarks/operators.txt",
+ true
+ );
+ BufferedWriter writer = new BufferedWriter(fw);
+
+
+ System.out.println("Operators: " + operators.size());
+
+ writer.write(args[3] + ": " + operators.size());
+ writer.newLine();
+ writer.flush();
+ writer.close();&*/
+ wayangContext.setLogLevel(Level.DEBUG);
+
+ if (!"vae".equals(modelType) && !"bvae".equals(modelType)) {
+ System.out.println("Executing query " + args[3]);
+ wayangContext.execute(plan, jars);
+ System.out.println("Finished execution");
+ } else {
+ System.out.println("Using vae cost model");
+ System.out.println("Executing query " + args[3]);
+ wayangContext.executeVAE(plan, jars);
+ System.out.println("Finished execution");
+ }
+ } catch (SqlParseException sqlE) {
+ sqlE.printStackTrace();
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static void setMLModel(Configuration config, String modelType, String path, String experiencePath) {
+ config.setProperty(
+ "wayang.ml.model.file",
+ path
+ );
+
+ switch(modelType) {
+ case "cost":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-cost.txt");
+
+ config.setCostModel(new PointwiseCost());
+ System.out.println("Using cost ML Model");
+
+ break;
+ case "pairwise":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-pairwise.txt");
+ config.setCostModel(new PairwiseCost());
+
+ System.out.println("Using pairwise ML Model");
+ break;
+ case "bvae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-bvae.txt");
+
+ System.out.println("Using bvae ML Model");
+ break;
+ case "vae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-vae.txt");
+
+ System.out.println("Using vae ML Model");
+ break;
+ default:
+ System.out.println("Using default cost Model");
+ break;
+ }
+
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/LSBORunner.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/LSBORunner.java
new file mode 100644
index 000000000..0ba183a2a
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/LSBORunner.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.Job;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.Operator;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.core.util.Tuple;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.spark.Spark;
+
+import org.apache.wayang.api.python.executor.ProcessFeeder;
+import org.apache.wayang.api.python.executor.ProcessReceiver;
+
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.util.ExplainUtils;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.wayang.ml.training.LSBO;
+import org.apache.wayang.ml.training.TPCH;
+import org.apache.wayang.apps.tpch.queries.Query1Wayang;
+import org.apache.wayang.apps.tpch.queries.Query3;
+import org.apache.wayang.apps.tpch.queries.Query5;
+import org.apache.wayang.apps.tpch.queries.Query6;
+import org.apache.wayang.apps.tpch.queries.Query10;
+import org.apache.wayang.apps.tpch.queries.Query12;
+import org.apache.wayang.apps.tpch.queries.Query14;
+import org.apache.wayang.apps.tpch.queries.Query19;
+import org.apache.wayang.ml.encoding.OneHotMappings;
+import org.apache.wayang.ml.encoding.TreeEncoder;
+import org.apache.wayang.ml.encoding.TreeNode;
+import org.apache.wayang.api.DataQuanta;
+import org.apache.wayang.api.PlanBuilder;
+import org.apache.wayang.ml.training.GeneratableJob;
+import org.apache.wayang.ml.benchmarks.IMDBJOBenchmark;
+import org.apache.wayang.ml.benchmarks.JOBenchmark;
+import org.apache.wayang.ml.util.Jobs;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.lang.reflect.Constructor;
+import java.util.HashMap;
+import java.util.List;
+import java.util.LinkedList;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.io.File;
+import java.io.IOException;
+import java.time.Instant;
+import java.time.Duration;
+import java.net.InetAddress;
+import java.net.Socket;
+import java.net.ServerSocket;
+
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+import com.google.protobuf.ByteString;
+
+/**
+ * TODO:
+ * - Move this to a class so that LSBO is a utility function
+ * -- Takes wayang plan as input
+ * -- Encodes it and sends it to python
+ * -- Receives set of encoded strings from latent space
+ * -- Executes each of those on the original plan
+ */
+public class LSBORunner {
+
+ public static String psqlUser = "ucloud";
+ public static String psqlPassword = "ucloud";
+
+ public static void main(String[] args) {
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ Configuration config = new Configuration();
+
+ final String calciteModel = "{\n" +
+ " \"version\": \"1.0\",\n" +
+ " \"defaultSchema\": \"wayang\",\n" +
+ " \"schemas\": [\n" +
+ " {\n" +
+ " \"name\": \"postgres\",\n" +
+ " \"type\": \"custom\",\n" +
+ " \"factory\": \"org.apache.wayang.api.sql.calcite.jdbc.JdbcSchema$Factory\",\n" +
+ " \"operand\": {\n" +
+ " \"jdbcDriver\": \"org.postgresql.Driver\",\n" +
+ " \"jdbcUrl\": \"jdbc:postgresql://job:5432/job\",\n" +
+ " \"jdbcUser\": \"" + psqlUser + "\",\n" +
+ " \"jdbcPassword\": \"" + psqlPassword + "\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ config.load(ReflectionUtils.loadResource("wayang-api-python-defaults.properties"));
+ config.setProperty("org.apache.calcite.sql.parser.parserTracing", "true");
+ config.setProperty("wayang.calcite.model", calciteModel);
+ config.setProperty("wayang.postgres.jdbc.url", "jdbc:postgresql://job:5432/job");
+ config.setProperty("wayang.postgres.jdbc.user", psqlUser);
+ config.setProperty("wayang.postgres.jdbc.password", psqlPassword);
+
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "JOB Query");
+ config.setProperty("spark.rpc.message.maxSize", "2047");
+
+ // Executor memory (on 48GB nodes)
+ config.setProperty("spark.executor.memory", "36g");
+ config.setProperty("spark.executor.memoryOverhead", "4g");
+ config.setProperty("spark.executor.cores", "8");
+
+ // Driver memory (on 48GB master node)
+ config.setProperty("spark.driver.memory", "24g");
+ config.setProperty("spark.driver.memoryOverhead", "2g");
+ config.setProperty("spark.driver.maxResultSize", "8g");
+
+ // Dynamic allocation (drop spark.executor.instances)
+ config.setProperty("spark.dynamicAllocation.enabled", "true");
+ config.setProperty("spark.dynamicAllocation.minExecutors", "2");
+ config.setProperty("spark.dynamicAllocation.maxExecutors", "2");
+ config.setProperty("spark.dynamicAllocation.initialExecutors", "2");
+ config.setProperty("spark.shuffle.service.enabled", "true");
+
+ // Parallelism
+ config.setProperty("spark.default.parallelism", "16"); // 2 executors * 8 cores
+
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "1");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("wayang.flink.rest.client.max-content-length", "200MiB");
+ config.setProperty("wayang.flink.collect.path", "file:///work/lsbo-paper/data/flink-data");
+ //config.setProperty("wayang.flink.collect.path", "file:///tmp/flink-data");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+ /*
+ config.setProperty(
+ "wayang.core.optimizer.pruning.strategies",
+ "org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy"
+ );
+ config.setProperty("wayang.core.optimizer.pruning.topk", "10000");*/
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(LSBORunner.class),
+ ReflectionUtils.getLibs(LSBORunner.class)
+ );
+
+ jars = ArrayUtils.addAll(
+ jars,
+ ReflectionUtils.getAllJars(org.apache.calcite.rel.externalize.RelJson.class)
+ );
+
+
+ /*
+ HashMap plans = TPCH.createPlans(args[1]);
+ WayangPlan plan = plans.get("query" + args[2]);*/
+
+ try {
+ //WayangPlan plan = getTPCHPlan(args[0], args[1], Integer.parseInt(args[2]));
+ WayangPlan plan = getJOBPlan(plugins, args[1], config, args[2], jars);
+
+ //Set sink to be on Java
+ /*
+ ((LinkedList) plan.getSinks())
+ .get(0)
+ .addTargetPlatform(Java.platform());*/
+
+ LSBO.process(plan, config, plugins, jars);
+ } catch(Exception e) {
+ System.out.println(e.getMessage());
+ }
+
+ }
+
+ private static WayangPlan getJOBPlan(List plugins, String dataPath, Configuration config, String queryPath, String[] jars) {
+ try {
+ WayangPlan plan = IMDBJOBenchmark.getWayangPlan(queryPath, config, plugins.toArray(Plugin[]::new), jars);
+ IMDBJOBenchmark.setSources(plan, dataPath);
+
+ return plan;
+ } catch (Exception e) {
+ e.printStackTrace();
+
+ return null;
+ }
+ }
+
+ private static WayangPlan getTPCHPlan(String platforms, String dataPath, int query) {
+ try {
+ Class extends GeneratableJob> job = Jobs.getJob(query);
+
+ Constructor> cnstr = job.getDeclaredConstructors()[0];
+ GeneratableJob createdJob = (GeneratableJob) cnstr.newInstance();
+ String[] jobArgs = {platforms, dataPath};
+ DataQuanta> quanta = createdJob.buildPlan(jobArgs);
+ PlanBuilder builder = quanta.getPlanBuilder();
+ WayangPlan plan = builder.build();
+
+ return plan;
+ } catch (Exception e) {
+ e.printStackTrace();
+
+ return null;
+ }
+ }
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/TPCHBenchmarks.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/TPCHBenchmarks.java
new file mode 100644
index 000000000..3024ffde8
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/TPCHBenchmarks.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.spark.Spark;
+
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.wayang.ml.training.TPCH;
+import org.apache.wayang.apps.tpch.queries.Query1Wayang;
+import org.apache.wayang.apps.tpch.queries.Query3;
+import org.apache.wayang.apps.tpch.queries.Query5;
+import org.apache.wayang.apps.tpch.queries.Query6;
+import org.apache.wayang.apps.tpch.queries.Query10;
+import org.apache.wayang.apps.tpch.queries.Query12;
+import org.apache.wayang.apps.tpch.queries.Query14;
+import org.apache.wayang.apps.tpch.queries.Query19;
+import org.apache.wayang.basic.operators.TextFileSource;
+import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
+import org.apache.wayang.core.plan.wayangplan.OutputSlot;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.apps.tpch.data.OrderTuple;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.HashMap;
+import java.util.List;
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+
+public class TPCHBenchmarks {
+
+ /**
+ * 0: platforms
+ * 1: TPCH data set directory path
+ * 2: Directory to write timings to
+ * 3: query number
+ * 4: model type
+ * 5: model path
+ * 6: experience path
+ */
+ public static void main(String[] args) {
+ try {
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ Configuration config = new Configuration();
+ String modelType = "";
+
+ config.setProperty("spark.master", "spark://spark-cluster:7077");
+ config.setProperty("spark.app.name", "TPC-H Benchmark Query " + args[3]);
+ config.setProperty("spark.executor.memory", "16g");
+ config.setProperty("spark.eventLog.enabled", "true");
+ config.setProperty("wayang.flink.mode.run", "distribution");
+ config.setProperty("wayang.flink.parallelism", "8");
+ config.setProperty("wayang.flink.master", "flink-cluster");
+ config.setProperty("wayang.flink.port", "7071");
+ config.setProperty("spark.app.name", "TPC-H Benchmark Query " + args[3]);
+ config.setProperty("spark.executor.memory", "16g");
+ config.setProperty("wayang.ml.experience.enabled", "false");
+
+ if (args.length > 4) {
+ modelType = args[4];
+ }
+
+ if (args.length > 6) {
+ TPCHBenchmarks.setMLModel(config, modelType, args[5], args[6]);
+ }
+
+ String executionTimeFile = args[2] + "query" + args[3] + "-executions";
+ String optimizationTimeFile = args[2] + "query" + args[3] + "-optimizations";
+
+ if (!"".equals(modelType)) {
+ executionTimeFile += "-" + modelType;
+ optimizationTimeFile += "-" + modelType;
+ }
+
+ config.setProperty(
+ "wayang.ml.executions.file",
+ executionTimeFile + ".txt"
+ );
+
+ config.setProperty(
+ "wayang.ml.optimizations.file",
+ optimizationTimeFile + ".txt"
+ );
+
+ final MLContext wayangContext = new MLContext(config);
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+ HashMap plans = TPCH.createPlans(args[1]);
+ WayangPlan plan = plans.get("query" + args[3]);
+
+ plan.collectReachableTopLevelSources().forEach(source -> {
+ if (source instanceof TextFileSource) {
+
+ String inputUrl = ((TextFileSource) source).getInputUrl();
+ System.out.println("SAUCE: " + inputUrl);
+
+ if (inputUrl.equals("file:///opt/data/orders.tbl")) {
+ TextFileSource orderText = new TextFileSource(inputUrl, "UTF-8");
+
+ MapOperator orderParser = new MapOperator<>(
+ (line) -> new OrderTuple.Parser().parse(line, '|'),
+ String.class,
+ OrderTuple.class
+ );
+ orderText.connectTo(0, orderParser, 0);
+
+ OutputSlot.stealConnections(source.getOutput(0).getOccupiedSlots().get(0).getOwner(), orderParser);
+ }
+ }
+ });
+
+ String[] jars = ArrayUtils.addAll(
+ ReflectionUtils.getAllJars(TPCHBenchmarks.class),
+ ReflectionUtils.getAllJars(org.apache.calcite.rel.externalize.RelJson.class)
+ );
+
+ System.out.println(modelType);
+ if (!"vae".equals(modelType) && !"bvae".equals(modelType)) {
+ System.out.println("Executing query " + args[3]);
+ wayangContext.execute(plan, jars);
+ System.out.println("Finished execution");
+ } else {
+ System.out.println("Using vae cost model");
+ System.out.println("Executing query " + args[3]);
+ wayangContext.executeVAE(plan, jars);
+ System.out.println("Finished execution");
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+
+ private static void setMLModel(Configuration config, String modelType, String path, String experiencePath) {
+ config.setProperty(
+ "wayang.ml.model.file",
+ path
+ );
+
+ switch(modelType) {
+ case "cost":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-cost.txt");
+
+ config.setCostModel(new PointwiseCost());
+ System.out.println("Using cost ML Model");
+
+ break;
+ case "pairwise":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-pairwise.txt");
+ config.setCostModel(new PairwiseCost());
+
+ System.out.println("Using pairwise ML Model");
+ break;
+ case "bvae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-bvae.txt");
+
+ System.out.println("Using bvae ML Model");
+ break;
+ case "vae":
+ config.setProperty("wayang.ml.experience.enabled", "true");
+ config.setProperty("wayang.ml.experience.file", experiencePath + "experience-vae.txt");
+
+ System.out.println("Using vae ML Model");
+ break;
+ default:
+ System.out.println("Using default cost Model");
+ break;
+ }
+
+ }
+
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/WordCount.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/WordCount.java
new file mode 100644
index 000000000..47dff35a4
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/WordCount.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.function.FlatMapDescriptor;
+import org.apache.wayang.core.function.ReduceDescriptor;
+import org.apache.wayang.core.function.TransformationDescriptor;
+import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.types.DataUnitType;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.java.platform.JavaPlatform;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.spark.platform.SparkPlatform;
+import org.apache.wayang.ml.MLContext;
+import org.apache.wayang.ml.costs.MLCost;
+import org.apache.wayang.ml.costs.PairwiseCost;
+import org.apache.wayang.ml.costs.PointwiseCost;
+import org.apache.logging.log4j.Level;
+import org.apache.wayang.apps.util.Parameters;
+import org.apache.wayang.core.plugin.Plugin;
+
+import scala.collection.Seq;
+import scala.collection.JavaConversions;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * Example Apache Wayang (incubating) App that does a word count -- the Hello World of Map/Reduce-like systems.
+ */
+public class WordCount {
+
+ /**
+ * Creates the {@link WayangPlan} for the word count app.
+ *
+ * @param inputFileUrl the file whose words should be counted
+ */
+ public static WayangPlan createWayangPlan(String inputFileUrl, Collection> collector) throws URISyntaxException, IOException {
+ // Assignment mode: none.
+
+ TextFileSource textFileSource = new TextFileSource(inputFileUrl);
+ textFileSource.setName("Load file");
+
+ // for each line (input) output an iterator of the words
+ FlatMapOperator flatMapOperator = new FlatMapOperator<>(
+ new FlatMapDescriptor<>(line -> Arrays.asList(line.split("\\W+")),
+ String.class,
+ String.class,
+ new ProbabilisticDoubleInterval(100, 10000, 0.8)
+ )
+ );
+ flatMapOperator.setName("Split words");
+
+ FilterOperator filterOperator = new FilterOperator<>(str -> !str.isEmpty(), String.class);
+ filterOperator.setName("Filter empty words");
+
+
+ // for each word transform it to lowercase and output a key-value pair (word, 1)
+ MapOperator> mapOperator = new MapOperator<>(
+ new TransformationDescriptor<>(word -> new Tuple2<>(word.toLowerCase(), 1),
+ DataUnitType.createBasic(String.class),
+ DataUnitType.createBasicUnchecked(Tuple2.class)
+ ), DataSetType.createDefault(String.class),
+ DataSetType.createDefaultUnchecked(Tuple2.class)
+ );
+ mapOperator.setName("To lower case, add counter");
+
+
+ // groupby the key (word) and add up the values (frequency)
+ ReduceByOperator, String> reduceByOperator = new ReduceByOperator<>(
+ new TransformationDescriptor<>(pair -> pair.field0,
+ DataUnitType.createBasicUnchecked(Tuple2.class),
+ DataUnitType.createBasic(String.class)), new ReduceDescriptor<>(
+ ((a, b) -> {
+ a.field1 += b.field1;
+ return a;
+ }), DataUnitType.createGroupedUnchecked(Tuple2.class),
+ DataUnitType.createBasicUnchecked(Tuple2.class)
+ ), DataSetType.createDefaultUnchecked(Tuple2.class)
+ );
+ reduceByOperator.setName("Add counters");
+
+
+ // write results to a sink
+ LocalCallbackSink> sink = LocalCallbackSink.createCollectingSink(
+ collector,
+ DataSetType.createDefaultUnchecked(Tuple2.class)
+ );
+ sink.setName("Collect result");
+
+ // Build Rheem plan by connecting operators
+ textFileSource.connectTo(0, flatMapOperator, 0);
+ flatMapOperator.connectTo(0, filterOperator, 0);
+ filterOperator.connectTo(0, mapOperator, 0);
+ mapOperator.connectTo(0, reduceByOperator, 0);
+ reduceByOperator.connectTo(0, sink, 0);
+
+ return new WayangPlan(sink);
+ }
+
+ public static void main(String[] args) throws IOException, URISyntaxException {
+ try {
+ if (args.length == 0) {
+ System.err.print("Usage: [,]* ");
+ System.exit(1);
+ }
+
+ List> collector = new LinkedList<>();
+ WayangPlan wayangPlan = createWayangPlan(args[1], collector);
+
+ Configuration config = new Configuration();
+ /*
+ config.setProperty(
+ "wayang.ml.model.file",
+ "/var/www/html/wayang-plugins/wayang-ml/src/main/resources/pairwise.onnx"
+ );*/
+
+ config.setProperty(
+ "wayang.ml.model.file",
+ "/var/www/html/wayang-plugins/wayang-ml/src/main/resources/cost.onnx"
+ );
+
+ config.setProperty(
+ "wayang.core.log.enabled",
+ "false"
+ );
+
+ //config.setCostModel(new PairwiseCost());
+ config.setCostModel(new PointwiseCost());
+ final MLContext wayangContext = new MLContext(config);
+ //wayangContext.setLogLevel(Level.DEBUG);
+
+ List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0]));
+ plugins.stream().forEach(plug -> wayangContext.register(plug));
+
+ wayangContext.execute(wayangPlan, ReflectionUtils.getDeclaringJar(WordCount.class), ReflectionUtils.getDeclaringJar(JavaPlatform.class));
+
+ collector.sort((t1, t2) -> Integer.compare(t2.field1, t1.field1));
+ System.out.printf("Found %d words:\n", collector.size()); //collector.forEach(wc -> System.out.printf("%dx %s\n", wc.field1, wc.field0));
+ } catch (Exception e) {
+ System.err.println("App failed.");
+ e.printStackTrace();
+ System.exit(4);
+ }
+ }
+
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query1.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query1.java
new file mode 100644
index 000000000..817265ae5
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query1.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.wayang.ml.benchmarks.job.complex;
+
+import org.apache.wayang.apps.imdb.data.*;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.core.api.WayangContext;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.java.Java;
+import org.apache.wayang.java.platform.JavaPlatform;
+import org.apache.wayang.spark.Spark;
+import org.apache.wayang.spark.platform.SparkPlatform;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.JoinOperator;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.basic.operators.GlobalReduceOperator;
+import org.apache.wayang.core.types.DataSetType;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+
+/*
+ * SELECT
+ * MIN(cn.name) AS company_name,
+ * MIN(lt.link) AS link_type,
+ * MIN(t.title) AS western_follow_up
+ * FROM
+ * company_name AS cn,
+ * company_type AS ct,
+ * keyword AS k,
+ * link_type AS lt,
+ * movie_companies AS mc,
+ * movie_info AS mi,
+ * movie_keyword AS mk,
+ * movie_link AS ml,
+ * title AS t
+ * WHERE
+ * cn.country_code !='[pl]'
+ * AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%')
+ * AND ct.kind ='production companies'
+ * AND k.keyword ='sequel'
+ * AND lt.link LIKE '%follow%'
+ * AND mc.note IS NULL
+ * AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark')
+ * AND t.production_year BETWEEN 1950 AND 2000
+ * AND lt.id = ml.link_type_id
+ * AND ml.movie_id = t.id
+ * AND t.id = mk.movie_id
+ * AND mk.keyword_id = k.id
+ * AND t.id = mc.movie_id
+ * AND mc.company_type_id = ct.id
+ * AND mc.company_id = cn.id
+ * AND ml.movie_id = mk.movie_id
+ * AND mk.movie_id = mc.movie_id
+ * AND ml.movie_id = mi.movie_id
+ * AND mc.movie_id = mi.movie_id
+ * AND cn.name_pcode_nf = cn.name_pcode_sf
+ * AND mi.movie_id = t.id
+ * AND ml.movie_id = mc.movie_id
+ * AND mk.movie_id = mi.movie_id;
+ */
+public class Query1 {
+
+ public static WayangPlan getWayangPlan(String dataPath, Collection collector){
+ TextFileSource companyNameText = new TextFileSource(dataPath + "company_name.csv", "UTF-8");
+ TextFileSource companyTypeText = new TextFileSource(dataPath + "company_type.csv", "UTF-8");
+ TextFileSource keywordText = new TextFileSource(dataPath + "keyword.csv", "UTF-8");
+ TextFileSource linkTypeText = new TextFileSource(dataPath + "link_type.csv", "UTF-8");
+ TextFileSource movieCompaniesText = new TextFileSource(dataPath + "movie_companies.csv", "UTF-8");
+ TextFileSource movieInfoText = new TextFileSource(dataPath + "movie_info.csv", "UTF-8");
+ TextFileSource movieKeywordText = new TextFileSource(dataPath + "movie_keyword.csv", "UTF-8");
+ TextFileSource movieLinkText = new TextFileSource(dataPath + "movie_link.csv", "UTF-8");
+ TextFileSource titleText = new TextFileSource(dataPath + "title.csv", "UTF-8");
+
+ MapOperator cnParser = new MapOperator(
+ (line) -> CompanyName.parseCsv(line),
+ String.class,
+ CompanyName.class
+ );
+
+ MapOperator ctParser = new MapOperator(
+ (line) -> CompanyType.parseCsv(line),
+ String.class,
+ CompanyType.class
+ );
+
+ MapOperator kParser = new MapOperator(
+ (line) -> Keyword.parseCsv(line),
+ String.class,
+ Keyword.class
+ );
+
+ MapOperator ltParser = new MapOperator(
+ (line) -> LinkType.parseCsv(line),
+ String.class,
+ LinkType.class
+ );
+
+ MapOperator mcParser = new MapOperator(
+ (line) -> MovieCompanies.parseCsv(line),
+ String.class,
+ MovieCompanies.class
+ );
+
+ MapOperator miParser = new MapOperator(
+ (line) -> MovieInfo.parseCsv(line),
+ String.class,
+ MovieInfo.class
+ );
+
+ MapOperator mkParser = new MapOperator(
+ (line) -> MovieKeyword.parseCsv(line),
+ String.class,
+ MovieKeyword.class
+ );
+
+ MapOperator mkParserTwo = new MapOperator(mkParser);
+
+ MapOperator mlParser = new MapOperator(
+ (line) -> MovieLink.parseCsv(line),
+ String.class,
+ MovieLink.class
+ );
+
+ MapOperator tParser = new MapOperator(
+ (line) -> Title.parseCsv(line),
+ String.class,
+ Title.class
+ );
+
+ FilterOperator cnFilter = new FilterOperator(
+ (cn) -> cn.countryCode() != "[pl]",
+ CompanyName.class
+ );
+
+ FilterOperator cnFilterTwo = new FilterOperator(
+ (cn) -> cn.name().contains("Film") || cn.name().contains("Warner"),
+ CompanyName.class
+ );
+
+ FilterOperator cnFilterThree = new FilterOperator(cnFilterTwo);
+
+ FilterOperator ctFilter = new FilterOperator(
+ (ct) -> ct.kind().equals("production companies"),
+ CompanyType.class
+ );
+
+ FilterOperator kFilter = new FilterOperator(
+ (k) -> k.keyword().equals("sequel"),
+ Keyword.class
+ );
+
+ FilterOperator ltFilter = new FilterOperator(
+ (lt) -> lt.link().contains("follow"),
+ LinkType.class
+ );
+
+ FilterOperator mcFilter = new FilterOperator(
+ (mc) -> mc.note() == null,
+ MovieCompanies.class
+ );
+
+ FilterOperator mcFilterTwo = new FilterOperator(mcFilter);
+ FilterOperator mcFilterThree = new FilterOperator(mcFilter);
+
+ FilterOperator miFilter = new FilterOperator(
+ (mi) -> Arrays.asList(new String[] {"Sweden", "Norway", "Germany", "Denmark"}).contains(mi.info()),
+ MovieInfo.class
+ );
+
+ FilterOperator miFilterTwo = new FilterOperator(miFilter);
+ FilterOperator miFilterThree = new FilterOperator(miFilter);
+
+ FilterOperator tFilter = new FilterOperator(
+ (t) -> t.productionYear() >= 1950 && t.productionYear() <= 2000,
+ Title.class
+ );
+
+ FilterOperator tFilterTwo = new FilterOperator(tFilter);
+
+ JoinOperator ltMlJoin = new JoinOperator(
+ (lt) -> lt.id(),
+ (ml) -> ml.id(),
+ LinkType.class,
+ MovieLink.class,
+ Integer.class
+ );
+
+ JoinOperator, Title, Integer> ltMlTJoin = new JoinOperator, Title, Integer>(
+ (ltMl) -> ltMl.field1.movieId(),
+ (t) -> t.id(),
+ ReflectionUtils.specify(Tuple2.class),
+ Title.class,
+ Integer.class
+ );
+
+ JoinOperator, Title>, MovieKeyword, Integer> ltMlTMkJoin = new JoinOperator, Title>, MovieKeyword, Integer>(
+ (ltMlT) -> ltMlT.field1.id(),
+ (mk) -> mk.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieKeyword.class,
+ Integer.class
+ );
+
+ JoinOperator, Title>, MovieKeyword>, Keyword, Integer> ltMlTMkKJoin = new JoinOperator, Title>, MovieKeyword>, Keyword, Integer>(
+ (ltMlTMk) -> ltMlTMk.field1.keywordId(),
+ (k) -> k.id(),
+ ReflectionUtils.specify(Tuple2.class),
+ Keyword.class,
+ Integer.class
+ );
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies, Integer> ltMlTMkKMcJoin = new JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies, Integer>(
+ (ltMlTMkK) -> ltMlTMkK.field0.field1.id(),
+ (mc) -> mc.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieCompanies.class,
+ Integer.class
+ );
+
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies>, CompanyType, Integer> ltMlTMkKMcCtJoin = new JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies>, CompanyType, Integer>(
+ (ltMlTMkKMc) -> ltMlTMkKMc.field1.companyTypeId(),
+ (ct) -> ct.id(),
+ ReflectionUtils.specify(Tuple2.class),
+ CompanyType.class,
+ Integer.class
+ );
+
+
+ JoinOperator ltMlTMkKMcCtCnJoin = new JoinOperator(
+ (ltMlTMkKMcCt) -> ltMlTMkKMcCt.field0.field1.companyId(),
+ (cn) -> cn.id(),
+ ReflectionUtils.specify(LtMlTMkKMcCt.class),
+ CompanyName.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieKeyword, Integer> ltMlTMkKMcCtCnMkJoin = new JoinOperator, MovieKeyword, Integer>(
+ (ltMlTMkKMcCtCn) -> ltMlTMkKMcCtCn.field0.field0.field0.field0.field1.movieId(),
+ (mk) -> mk.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieKeyword.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieKeyword>, MovieCompanies, Integer> ltMlTMkKMcCtCnMkMcJoin = new JoinOperator, MovieKeyword>, MovieCompanies, Integer>(
+ (ltMlTMkKMcCtCnMk) -> ltMlTMkKMcCtCnMk.field1.movieId(),
+ (mc) -> mc.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieCompanies.class,
+ Integer.class
+ );
+
+ JoinOperator ltMlTMkKMcCtCnMkMcMiJoin = new JoinOperator(
+ (ltMlTMkKMcCtCnMkMc) -> ltMlTMkKMcCtCnMkMc.field0.field0.field0.field0.field0.field0.field0.field0.field1.movieId(),
+ (mi) -> mi.movieId(),
+ ReflectionUtils.specify(LtMlTMkKMcCtCnMkMc.class),
+ MovieInfo.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieInfo, Integer> ltMlTMkKMcCtCnMkMcMiMiJoin = new JoinOperator, MovieInfo, Integer>(
+ (ltMlTMkKMcCtCnMkMcMi) -> ltMlTMkKMcCtCnMkMcMi.field0.field0.field0.field0.field0.field1.movieId(),
+ (mi) -> mi.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieInfo.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieInfo>, CompanyName, String> ltMlTMkKMcCtCnMkMcMiMiCnJoin = new JoinOperator, MovieInfo>, CompanyName, String>(
+ (ltMlTMkKMcCtCnMkMcMiMi) -> ltMlTMkKMcCtCnMkMcMiMi.field0.field0.field0.field0.field1.namePcodeNf(),
+ (cn) -> cn.namePcodeSf(),
+ ReflectionUtils.specify(Tuple2.class),
+ CompanyName.class,
+ String.class
+ );
+
+ JoinOperator ltMlTMkKMcCtCnMkMcMiMiCnTJoin = new JoinOperator(
+ (ltMlTMkKMcCtCnMkMcMiMiCn) -> ltMlTMkKMcCtCnMkMcMiMiCn.field0.field1.movieId(),
+ (t) -> t.id(),
+ ReflectionUtils.specify(LtMlTMkKMcCtCnMkMcMiMiCn.class),
+ Title.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieCompanies, Integer> ltMlTMkKMcCtCnMkMcMiMiCnTMcJoin = new JoinOperator, MovieCompanies, Integer>(
+ (ltMlTMkKMcCtCnMkMcMiMiCnT) -> ltMlTMkKMcCtCnMkMcMiMiCnT.field0.field0.field0.field0.field0.field0.field0.field0.field1.movieId(),
+ (mc) -> mc.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieCompanies.class,
+ Integer.class
+ );
+
+ JoinOperator, MovieCompanies>, MovieInfo, Integer> ltMlTMkKMcCtCnMkMcMiMiCnTMcMiJoin = new JoinOperator, MovieCompanies>, MovieInfo, Integer>(
+ (ltMlTMkKMcCtCnMkMcMiMiCnTMc) -> ltMlTMkKMcCtCnMkMcMiMiCnTMc.field0.field0.field0.field0.field0.field0.field1.movieId(),
+ (mi) -> mi.movieId(),
+ ReflectionUtils.specify(Tuple2.class),
+ MovieInfo.class,
+ Integer.class
+ );
+
+
+ ReduceByOperator cnMin = new ReduceByOperator(
+ (tuple) -> tuple.field0.field0.field0.field1.name(),
+ (t1, t2) -> {
+ return t1.field0.field0.field0.field1.name().compareTo(t2.field0.field0.field0.field1.name()) <= 0 ? t1 : t2;
+ },
+ String.class,
+ LtMlTMkKMcCtCnMkMcMiMiCnTMcMi.class
+ );
+
+ ReduceByOperator ltMin = new ReduceByOperator(
+ (tuple) -> tuple.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.link(),
+ (t1, t2) -> {
+ return t1.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.link().compareTo(t2.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.field0.link()) <= 0 ? t1 : t2;
+ },
+ String.class,
+ LtMlTMkKMcCtCnMkMcMiMiCnTMcMi.class
+ );
+
+ ReduceByOperator tMin = new ReduceByOperator(
+ (tuple) -> tuple.field0.field0.field1.title(),
+ (t1, t2) -> {
+ return t1.field0.field0.field1.title().compareTo(t2.field0.field0.field1.title()) <= 0 ? t1 : t2;
+ },
+ String.class,
+ LtMlTMkKMcCtCnMkMcMiMiCnTMcMi.class
+ );
+
+ LocalCallbackSink sink = LocalCallbackSink.createCollectingSink(
+ collector,
+ DataSetType.createDefaultUnchecked(LtMlTMkKMcCtCnMkMcMiMiCnTMcMi.class)
+ );
+
+ //Connect all the operators
+ companyNameText.connectTo(0, cnParser, 0);
+ companyTypeText.connectTo(0, ctParser, 0);
+ keywordText.connectTo(0, kParser, 0);
+ linkTypeText.connectTo(0, ltParser, 0);
+ movieCompaniesText.connectTo(0, mcParser, 0);
+ movieInfoText.connectTo(0, miParser, 0);
+ movieKeywordText.connectTo(0, mkParser, 0);
+ movieLinkText.connectTo(0, mlParser, 0);
+ titleText.connectTo(0, tParser, 0);
+
+ cnParser.connectTo(0, cnFilter, 0);
+ cnFilter.connectTo(0, cnFilterTwo, 0);
+ kParser.connectTo(0, kFilter, 0);
+ ltParser.connectTo(0, ltFilter, 0);
+ mcParser.connectTo(0, mcFilter, 0);
+ tParser.connectTo(0, tFilter, 0);
+ miParser.connectTo(0, miFilter, 0);
+
+ ltFilter.connectTo(0, ltMlJoin, 0);
+ mlParser.connectTo(0, ltMlJoin, 1);
+
+ ltMlJoin.connectTo(0, ltMlTJoin, 0);
+ tFilter.connectTo(0, ltMlTJoin, 1);
+
+ ltMlTJoin.connectTo(0, ltMlTMkJoin, 0);
+ mkParser.connectTo(0, ltMlTMkJoin, 1);
+
+ ltMlTMkJoin.connectTo(0, ltMlTMkKJoin, 0);
+ kFilter.connectTo(0, ltMlTMkKJoin, 1);
+
+ ltMlTMkKJoin.connectTo(0, ltMlTMkKMcJoin, 0);
+ mcFilter.connectTo(0, ltMlTMkKMcJoin, 1);
+
+ ltMlTMkKMcJoin.connectTo(0, ltMlTMkKMcCtJoin, 0);
+ ctFilter.connectTo(0, ltMlTMkKMcCtJoin, 1);
+
+ ltMlTMkKMcCtJoin.connectTo(0, ltMlTMkKMcCtCnJoin, 0);
+ cnFilterTwo.connectTo(0, ltMlTMkKMcCtCnJoin, 1);
+
+ ltMlTMkKMcCtCnJoin.connectTo(0, ltMlTMkKMcCtCnMkJoin, 0);
+ mkParserTwo.connectTo(0, ltMlTMkKMcCtCnMkJoin, 1);
+
+ ltMlTMkKMcCtCnMkJoin.connectTo(0, ltMlTMkKMcCtCnMkMcJoin, 0);
+ mcFilterTwo.connectTo(0, ltMlTMkKMcCtCnMkMcJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcJoin.connectTo(0, ltMlTMkKMcCtCnMkMcMiJoin, 0);
+ miFilter.connectTo(0, ltMlTMkKMcCtCnMkMcMiJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcMiJoin.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiJoin, 0);
+ miFilterTwo.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcMiMiJoin.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnJoin, 0);
+ cnFilterThree.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcMiMiCnJoin.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnTJoin, 0);
+ tFilterTwo.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnTJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcMiMiCnTJoin.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnTMcJoin, 0);
+ mcFilterThree.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnTMcJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcMiMiCnTMcJoin.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnTMcMiJoin, 0);
+ miFilterThree.connectTo(0, ltMlTMkKMcCtCnMkMcMiMiCnTMcMiJoin, 1);
+
+ ltMlTMkKMcCtCnMkMcMiMiCnTMcMiJoin.connectTo(0, cnMin, 0);
+ cnMin.connectTo(0, ltMin, 0);
+ ltMin.connectTo(0, tMin , 0);
+ tMin.connectTo(0, sink, 0);
+
+
+ return new WayangPlan(sink);
+ }
+
+ //Some intermediate types for intermediate join results
+ private static class LtMlTMkKMcCt extends Tuple2, Title>, MovieKeyword>, Keyword>, MovieCompanies>, CompanyType>{}
+
+ private static class LtMlTMkKMcCtCnMkMc extends Tuple2, MovieKeyword>, MovieCompanies>{}
+
+ private static class LtMlTMkKMcCtCnMkMcMiMiCn extends Tuple2, MovieInfo>, CompanyName>{}
+
+ private static class LtMlTMkKMcCtCnMkMcMiMiCnTMcMi extends Tuple2, MovieCompanies>, MovieInfo>{}
+
+}
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query2.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query2.java
new file mode 100644
index 000000000..271fb33ae
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query2.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+
+package org.apache.wayang.ml.benchmarks.job.complex;
+
+import org.apache.wayang.apps.imdb.data.*;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.util.ReflectionUtils;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.JoinOperator;
+import org.apache.wayang.basic.operators.GlobalReduceOperator;
+import org.apache.wayang.core.types.DataSetType;
+
+import java.util.Collection;
+
+/*
+ * SELECT
+ * MIN(cn.name) AS from_company,
+ * MIN(lt.link) AS movie_link_type,
+ * MIN(t.title) AS non_polish_sequel_movie
+ * FROM
+ * company_name AS cn,
+ * company_type AS ct,
+ * keyword AS k,
+ * link_type AS lt,
+ * movie_companies AS mc,
+ * movie_keyword AS mk,
+ * movie_link AS ml,
+ * title AS t
+ * WHERE
+ * cn.country_code != '[pl]'
+ * AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%')
+ * AND ct.kind = 'production companies'
+ * AND k.keyword = 'sequel'
+ * AND lt.link LIKE '%follow%'
+ * AND mc.note IS NULL
+ * AND t.production_year BETWEEN 1950 AND 2000
+ * AND lt.id = ml.link_type_id
+ * AND ml.movie_id = t.id
+ * AND t.id = mk.movie_id
+ * AND mk.keyword_id = k.id
+ * AND t.id = mc.movie_id
+ * AND mc.company_type_id = ct.id
+ * AND mc.company_id = cn.id
+ * AND ml.movie_id = mk.movie_id
+ * AND ml.movie_id = mc.movie_id
+ * AND mk.movie_id = mc.movie_id
+ * AND cn.name_pcode_nf = cn.name_pcode_sf;
+ */
+public class Query2 {
+
+ public static WayangPlan getWayangPlan(String dataPath, Collection collector){
+
+ TextFileSource companyNameText = new TextFileSource(dataPath + "company_name.csv", "UTF-8");
+ TextFileSource companyTypeText = new TextFileSource(dataPath + "company_type.csv", "UTF-8");
+ TextFileSource keywordText = new TextFileSource(dataPath + "keyword.csv", "UTF-8");
+ TextFileSource linkTypeText = new TextFileSource(dataPath + "link_type.csv", "UTF-8");
+ TextFileSource movieCompaniesText = new TextFileSource(dataPath + "movie_companies.csv", "UTF-8");
+ TextFileSource movieKeywordText = new TextFileSource(dataPath + "movie_keyword.csv", "UTF-8");
+ TextFileSource movieLinkText = new TextFileSource(dataPath + "movie_link.csv", "UTF-8");
+ TextFileSource titleText = new TextFileSource(dataPath + "title.csv", "UTF-8");
+
+ MapOperator cnParser =
+ new MapOperator<>(CompanyName::parseCsv, String.class, CompanyName.class);
+
+ MapOperator ctParser =
+ new MapOperator<>(CompanyType::parseCsv, String.class, CompanyType.class);
+
+ MapOperator kParser =
+ new MapOperator<>(Keyword::parseCsv, String.class, Keyword.class);
+
+ MapOperator ltParser =
+ new MapOperator<>(LinkType::parseCsv, String.class, LinkType.class);
+
+ MapOperator mcParser =
+ new MapOperator<>(MovieCompanies::parseCsv, String.class, MovieCompanies.class);
+
+ MapOperator mkParser =
+ new MapOperator<>(MovieKeyword::parseCsv, String.class, MovieKeyword.class);
+
+ MapOperator mlParser =
+ new MapOperator<>(MovieLink::parseCsv, String.class, MovieLink.class);
+
+ MapOperator tParser =
+ new MapOperator<>(Title::parseCsv, String.class, Title.class);
+
+ FilterOperator cnFilter =
+ new FilterOperator<>(cn -> cn.countryCode() != "[pl]", CompanyName.class);
+
+ FilterOperator cnFilterTwo =
+ new FilterOperator<>(cn -> cn.name().contains("Film") || cn.name().contains("Warner"), CompanyName.class);
+
+ FilterOperator cnFilterThree = new FilterOperator<>(cnFilterTwo);
+
+ FilterOperator ctFilter =
+ new FilterOperator<>(ct -> ct.kind().equals("production companies"), CompanyType.class);
+
+ FilterOperator kFilter =
+ new FilterOperator<>(k -> k.keyword().equals("sequel"), Keyword.class);
+
+ FilterOperator ltFilter =
+ new FilterOperator<>(lt -> lt.link().contains("follow"), LinkType.class);
+
+ FilterOperator mcFilter =
+ new FilterOperator<>(mc -> mc.note() == null, MovieCompanies.class);
+
+ FilterOperator tFilter =
+ new FilterOperator<>(t -> t.productionYear() >= 1950 && t.productionYear() <= 2000, Title.class);
+
+ JoinOperator ltMlJoin =
+ new JoinOperator<>(LinkType::id, MovieLink::linkTypeId,
+ LinkType.class, MovieLink.class, Integer.class);
+
+ JoinOperator, Title, Integer> ltMlTJoin =
+ new JoinOperator<>(ltMl -> ltMl.field1.movieId(),
+ Title::id,
+ ReflectionUtils.specify(Tuple2.class),
+ Title.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword, Integer> ltMlTMkJoin =
+ new JoinOperator<>(ltMlT -> ltMlT.field1.id(),
+ MovieKeyword::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieKeyword.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword, Integer> ltMlTMkKJoin =
+ new JoinOperator<>(ltMlTMk -> ltMlTMk.field1.keywordId(),
+ Keyword::id,
+ ReflectionUtils.specify(Tuple2.class),
+ Keyword.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies, Integer> ltMlTMkKMcJoin =
+ new JoinOperator<>(ltMlTMkK -> ltMlTMkK.field0.field1.id(),
+ MovieCompanies::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieCompanies.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies>, CompanyType, Integer> ltMlTMkKMcCtJoin =
+ new JoinOperator<>(ltMlTMkKMc -> ltMlTMkKMc.field1.companyTypeId(),
+ CompanyType::id,
+ ReflectionUtils.specify(Tuple2.class),
+ CompanyType.class,
+ Integer.class);
+
+ JoinOperator ltMlTMkKMcCtCnJoin =
+ new JoinOperator<>(ltMlTMkKMcCt -> ltMlTMkKMcCt.field0.field1.companyId(),
+ CompanyName::id,
+ ReflectionUtils.specify(LtMlTMkKMcCt.class),
+ CompanyName.class,
+ Integer.class);
+
+ JoinOperator, CompanyName, String> cnSelfJoin =
+ new JoinOperator<>(tuple -> tuple.field1.namePcodeNf(),
+ CompanyName::namePcodeSf,
+ ReflectionUtils.specify(Tuple2.class),
+ CompanyName.class,
+ String.class);
+
+ ReduceByOperator cnMin =
+ new ReduceByOperator<>(
+ tuple -> tuple.field1.name(),
+ (t1, t2) -> t1.field1.name().compareTo(t2.field1.name()) <= 0 ? t1 : t2,
+ String.class,
+ LtMlTMkKMcCtCn.class
+ );
+
+ ReduceByOperator ltMin =
+ new ReduceByOperator<>(
+ tuple -> tuple.field0.field0.field0.field0.field0.field0.field0.link(),
+ (t1, t2) -> t1.field0.field0.field0.field0.field0.field0.field0.link()
+ .compareTo(t2.field0.field0.field0.field0.field0.field0.field0.link()) <= 0 ? t1 : t2,
+ String.class,
+ LtMlTMkKMcCtCn.class
+ );
+
+ ReduceByOperator tMin =
+ new ReduceByOperator<>(
+ tuple -> tuple.field0.field0.field0.field0.field0.field1.title(),
+ (t1, t2) -> t1.field0.field0.field0.field0.field0.field1.title()
+ .compareTo(t2.field0.field0.field0.field0.field0.field1.title()) <= 0 ? t1 : t2,
+ String.class,
+ LtMlTMkKMcCtCn.class
+ );
+
+ LocalCallbackSink sink =
+ LocalCallbackSink.createCollectingSink(
+ collector,
+ DataSetType.createDefaultUnchecked(LtMlTMkKMcCtCn.class)
+ );
+
+ // Connections
+ companyNameText.connectTo(0, cnParser, 0);
+ companyTypeText.connectTo(0, ctParser, 0);
+ keywordText.connectTo(0, kParser, 0);
+ linkTypeText.connectTo(0, ltParser, 0);
+ movieCompaniesText.connectTo(0, mcParser, 0);
+ movieKeywordText.connectTo(0, mkParser, 0);
+ movieLinkText.connectTo(0, mlParser, 0);
+ titleText.connectTo(0, tParser, 0);
+
+ cnParser.connectTo(0, cnFilter, 0);
+ cnFilter.connectTo(0, cnFilterTwo, 0);
+
+ ctParser.connectTo(0, ctFilter, 0);
+ kParser.connectTo(0, kFilter, 0);
+ ltParser.connectTo(0, ltFilter, 0);
+ mcParser.connectTo(0, mcFilter, 0);
+ tParser.connectTo(0, tFilter, 0);
+
+ ltFilter.connectTo(0, ltMlJoin, 0);
+ mlParser.connectTo(0, ltMlJoin, 1);
+
+ ltMlJoin.connectTo(0, ltMlTJoin, 0);
+ tFilter.connectTo(0, ltMlTJoin, 1);
+
+ ltMlTJoin.connectTo(0, ltMlTMkJoin, 0);
+ mkParser.connectTo(0, ltMlTMkJoin, 1);
+
+ ltMlTMkJoin.connectTo(0, ltMlTMkKJoin, 0);
+ kFilter.connectTo(0, ltMlTMkKJoin, 1);
+
+ ltMlTMkKJoin.connectTo(0, ltMlTMkKMcJoin, 0);
+ mcFilter.connectTo(0, ltMlTMkKMcJoin, 1);
+
+ ltMlTMkKMcJoin.connectTo(0, ltMlTMkKMcCtJoin, 0);
+ ctFilter.connectTo(0, ltMlTMkKMcCtJoin, 1);
+
+ ltMlTMkKMcCtJoin.connectTo(0, ltMlTMkKMcCtCnJoin, 0);
+ cnFilterTwo.connectTo(0, ltMlTMkKMcCtCnJoin, 1);
+
+ ltMlTMkKMcCtCnJoin.connectTo(0, cnSelfJoin, 0);
+ cnFilterThree.connectTo(0, cnSelfJoin, 1);
+
+ cnSelfJoin.connectTo(0, cnMin, 0);
+ cnMin.connectTo(0, ltMin, 0);
+ ltMin.connectTo(0, tMin, 0);
+ tMin.connectTo(0, sink, 0);
+
+ return new WayangPlan(sink);
+ }
+
+ private static class LtMlTMkKMcCt extends Tuple2, Title>, MovieKeyword>, Keyword>,
+ MovieCompanies>, CompanyType> {}
+
+ private static class LtMlTMkKMcCtCn extends Tuple2 {}
+}
+
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query3.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query3.java
new file mode 100644
index 000000000..28181f368
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query3.java
@@ -0,0 +1,284 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
+ */
+
+package org.apache.wayang.ml.benchmarks.job.complex;
+
+import org.apache.wayang.apps.imdb.data.*;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.util.ReflectionUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+/*
+ * SELECT
+ * MIN(cn.name) AS company_name,
+ * MIN(lt.link) AS link_type,
+ * MIN(t.title) AS western_follow_up
+ * FROM
+ * company_name AS cn,
+ * company_type AS ct,
+ * keyword AS k,
+ * link_type AS lt,
+ * movie_companies AS mc,
+ * movie_info AS mi,
+ * movie_keyword AS mk,
+ * movie_link AS ml,
+ * title AS t
+ * WHERE
+ * cn.country_code !='[pl]'
+ * AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%')
+ * AND ct.kind ='production companies'
+ * AND k.keyword ='sequel'
+ * AND lt.link LIKE '%follow%'
+ * AND mc.note IS NULL
+ * AND mi.info IN ('Sweden','Norway','Germany','Denmark','Swedish','Denish','Norwegian','German')
+ * AND t.production_year BETWEEN 1950 AND 2000
+ * AND lt.id = ml.link_type_id
+ * AND ml.movie_id = t.id
+ * AND t.id = mk.movie_id
+ * AND mk.keyword_id = k.id
+ * AND t.id = mc.movie_id
+ * AND mc.company_type_id = ct.id
+ * AND mi.movie_id = t.id
+ * AND ml.movie_id = mk.movie_id
+ * AND ml.movie_id = mc.movie_id
+ * AND ml.movie_id = mi.movie_id
+ * AND mk.movie_id = mi.movie_id
+ * AND mc.movie_id = mi.movie_id
+ * AND cn.name_pcode_sf = cn.name_pcode_nf
+ * AND mk.movie_id = mc.movie_id
+ * AND mc.company_id = cn.id;
+ */
+public class Query3 {
+
+ public static WayangPlan getWayangPlan(String dataPath, Collection collector){
+
+ TextFileSource companyNameText = new TextFileSource(dataPath + "company_name.csv", "UTF-8");
+ TextFileSource companyTypeText = new TextFileSource(dataPath + "company_type.csv", "UTF-8");
+ TextFileSource keywordText = new TextFileSource(dataPath + "keyword.csv", "UTF-8");
+ TextFileSource linkTypeText = new TextFileSource(dataPath + "link_type.csv", "UTF-8");
+ TextFileSource movieCompaniesText = new TextFileSource(dataPath + "movie_companies.csv", "UTF-8");
+ TextFileSource movieInfoText = new TextFileSource(dataPath + "movie_info.csv", "UTF-8");
+ TextFileSource movieKeywordText = new TextFileSource(dataPath + "movie_keyword.csv", "UTF-8");
+ TextFileSource movieLinkText = new TextFileSource(dataPath + "movie_link.csv", "UTF-8");
+ TextFileSource titleText = new TextFileSource(dataPath + "title.csv", "UTF-8");
+
+ MapOperator cnParser =
+ new MapOperator<>(CompanyName::parseCsv, String.class, CompanyName.class);
+ MapOperator ctParser =
+ new MapOperator<>(CompanyType::parseCsv, String.class, CompanyType.class);
+ MapOperator kParser =
+ new MapOperator<>(Keyword::parseCsv, String.class, Keyword.class);
+ MapOperator ltParser =
+ new MapOperator<>(LinkType::parseCsv, String.class, LinkType.class);
+ MapOperator mcParser =
+ new MapOperator<>(MovieCompanies::parseCsv, String.class, MovieCompanies.class);
+ MapOperator miParser =
+ new MapOperator<>(MovieInfo::parseCsv, String.class, MovieInfo.class);
+ MapOperator mkParser =
+ new MapOperator<>(MovieKeyword::parseCsv, String.class, MovieKeyword.class);
+ MapOperator mlParser =
+ new MapOperator<>(MovieLink::parseCsv, String.class, MovieLink.class);
+ MapOperator tParser =
+ new MapOperator<>(Title::parseCsv, String.class, Title.class);
+
+ FilterOperator cnFilter =
+ new FilterOperator<>(cn -> cn.countryCode() != "[pl]", CompanyName.class);
+ FilterOperator cnFilterTwo =
+ new FilterOperator<>(cn -> cn.name().contains("Film") || cn.name().contains("Warner"), CompanyName.class);
+ FilterOperator cnFilterThree = new FilterOperator<>(cnFilterTwo);
+
+ FilterOperator ctFilter =
+ new FilterOperator<>(ct -> ct.kind().equals("production companies"), CompanyType.class);
+
+ FilterOperator kFilter =
+ new FilterOperator<>(k -> k.keyword().equals("sequel"), Keyword.class);
+
+ FilterOperator ltFilter =
+ new FilterOperator<>(lt -> lt.link().contains("follow"), LinkType.class);
+
+ FilterOperator mcFilter =
+ new FilterOperator<>(mc -> mc.note() == null, MovieCompanies.class);
+
+ FilterOperator miFilter =
+ new FilterOperator<>(mi ->
+ Arrays.asList("Sweden","Norway","Germany","Denmark",
+ "Swedish","Denish","Norwegian","German")
+ .contains(mi.info()),
+ MovieInfo.class);
+
+ FilterOperator tFilter =
+ new FilterOperator<>(t -> t.productionYear() >= 1950 && t.productionYear() <= 2000, Title.class);
+
+ JoinOperator ltMlJoin =
+ new JoinOperator<>(LinkType::id, MovieLink::linkTypeId,
+ LinkType.class, MovieLink.class, Integer.class);
+
+ JoinOperator, Title, Integer> ltMlTJoin =
+ new JoinOperator<>(ltMl -> ltMl.field1.movieId(),
+ Title::id,
+ ReflectionUtils.specify(Tuple2.class),
+ Title.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword, Integer> ltMlTMkJoin =
+ new JoinOperator<>(ltMlT -> ltMlT.field1.id(),
+ MovieKeyword::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieKeyword.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword, Integer> ltMlTMkKJoin =
+ new JoinOperator<>(ltMlTMk -> ltMlTMk.field1.keywordId(),
+ Keyword::id,
+ ReflectionUtils.specify(Tuple2.class),
+ Keyword.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies, Integer> ltMlTMkKMcJoin =
+ new JoinOperator<>(ltMlTMkK -> ltMlTMkK.field0.field1.id(),
+ MovieCompanies::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieCompanies.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies>, CompanyType, Integer> ltMlTMkKMcCtJoin =
+ new JoinOperator<>(ltMlTMkKMc -> ltMlTMkKMc.field1.companyTypeId(),
+ CompanyType::id,
+ ReflectionUtils.specify(Tuple2.class),
+ CompanyType.class,
+ Integer.class);
+
+ JoinOperator ltMlTMkKMcCtCnJoin =
+ new JoinOperator<>(ltMlTMkKMcCt -> ltMlTMkKMcCt.field0.field1.companyId(),
+ CompanyName::id,
+ ReflectionUtils.specify(LtMlTMkKMcCt.class),
+ CompanyName.class,
+ Integer.class);
+
+ JoinOperator, MovieInfo, Integer> ltMlTMkKMcCtCnMiJoin =
+ new JoinOperator<>(tuple ->
+ tuple.field0.field0.field0.field0.field0.field0.field1.movieId(),
+ MovieInfo::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieInfo.class,
+ Integer.class);
+
+ JoinOperator cnSelfJoin =
+ new JoinOperator<>(tuple -> tuple.field0.field1.namePcodeSf(),
+ CompanyName::namePcodeNf,
+ ReflectionUtils.specify(LtMlTMkKMcCtCnMi.class),
+ CompanyName.class,
+ String.class);
+
+ ReduceByOperator cnMin =
+ new ReduceByOperator<>(
+ tuple -> tuple.field0.field1.name(),
+ (t1, t2) -> t1.field0.field1.name().compareTo(t2.field0.field1.name()) <= 0 ? t1 : t2,
+ String.class,
+ LtMlTMkKMcCtCnMi.class
+ );
+
+ ReduceByOperator ltMin =
+ new ReduceByOperator<>(
+ tuple -> tuple.field0.field0.field0.field0.field0.field0.field0.field0.link(),
+ (t1, t2) -> t1.field0.field0.field0.field0.field0.field0.field0.field0.link()
+ .compareTo(t2.field0.field0.field0.field0.field0.field0.field0.field0.link()) <= 0 ? t1 : t2,
+ String.class,
+ LtMlTMkKMcCtCnMi.class
+ );
+
+ ReduceByOperator tMin =
+ new ReduceByOperator<>(
+ tuple -> tuple.field0.field0.field0.field0.field0.field0.field1.title(),
+ (t1, t2) -> t1.field0.field0.field0.field0.field0.field0.field1.title()
+ .compareTo(t2.field0.field0.field0.field0.field0.field0.field1.title()) <= 0 ? t1 : t2,
+ String.class,
+ LtMlTMkKMcCtCnMi.class
+ );
+
+ LocalCallbackSink sink =
+ LocalCallbackSink.createCollectingSink(
+ collector,
+ DataSetType.createDefaultUnchecked(LtMlTMkKMcCtCnMi.class)
+ );
+
+ // Connections
+ companyNameText.connectTo(0, cnParser, 0);
+ companyTypeText.connectTo(0, ctParser, 0);
+ keywordText.connectTo(0, kParser, 0);
+ linkTypeText.connectTo(0, ltParser, 0);
+ movieCompaniesText.connectTo(0, mcParser, 0);
+ movieInfoText.connectTo(0, miParser, 0);
+ movieKeywordText.connectTo(0, mkParser, 0);
+ movieLinkText.connectTo(0, mlParser, 0);
+ titleText.connectTo(0, tParser, 0);
+
+ cnParser.connectTo(0, cnFilter, 0);
+ cnFilter.connectTo(0, cnFilterTwo, 0);
+
+ ctParser.connectTo(0, ctFilter, 0);
+ kParser.connectTo(0, kFilter, 0);
+ ltParser.connectTo(0, ltFilter, 0);
+ mcParser.connectTo(0, mcFilter, 0);
+ miParser.connectTo(0, miFilter, 0);
+ tParser.connectTo(0, tFilter, 0);
+
+ ltFilter.connectTo(0, ltMlJoin, 0);
+ mlParser.connectTo(0, ltMlJoin, 1);
+
+ ltMlJoin.connectTo(0, ltMlTJoin, 0);
+ tFilter.connectTo(0, ltMlTJoin, 1);
+
+ ltMlTJoin.connectTo(0, ltMlTMkJoin, 0);
+ mkParser.connectTo(0, ltMlTMkJoin, 1);
+
+ ltMlTMkJoin.connectTo(0, ltMlTMkKJoin, 0);
+ kFilter.connectTo(0, ltMlTMkKJoin, 1);
+
+ ltMlTMkKJoin.connectTo(0, ltMlTMkKMcJoin, 0);
+ mcFilter.connectTo(0, ltMlTMkKMcJoin, 1);
+
+ ltMlTMkKMcJoin.connectTo(0, ltMlTMkKMcCtJoin, 0);
+ ctFilter.connectTo(0, ltMlTMkKMcCtJoin, 1);
+
+ ltMlTMkKMcCtJoin.connectTo(0, ltMlTMkKMcCtCnJoin, 0);
+ cnFilterTwo.connectTo(0, ltMlTMkKMcCtCnJoin, 1);
+
+ ltMlTMkKMcCtCnJoin.connectTo(0, ltMlTMkKMcCtCnMiJoin, 0);
+ miFilter.connectTo(0, ltMlTMkKMcCtCnMiJoin, 1);
+
+ ltMlTMkKMcCtCnMiJoin.connectTo(0, cnSelfJoin, 0);
+ cnFilterThree.connectTo(0, cnSelfJoin, 1);
+
+ cnSelfJoin.connectTo(0, cnMin, 0);
+ cnMin.connectTo(0, ltMin, 0);
+ ltMin.connectTo(0, tMin, 0);
+ tMin.connectTo(0, sink, 0);
+
+ return new WayangPlan(sink);
+ }
+
+ private static class LtMlTMkKMcCt extends Tuple2<
+ Tuple2<
+ Tuple2<
+ Tuple2<
+ Tuple2<
+ Tuple2,
+ Title>,
+ MovieKeyword>,
+ Keyword>,
+ MovieCompanies>,
+ CompanyType> {}
+
+ private static class LtMlTMkKMcCtCnMi extends Tuple2<
+ Tuple2,
+ MovieInfo> {}
+}
+
diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query4.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query4.java
new file mode 100644
index 000000000..21407502e
--- /dev/null
+++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/job/complex/Query4.java
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
+ */
+
+package org.apache.wayang.ml.benchmarks.job.complex;
+
+import org.apache.wayang.apps.imdb.data.*;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.*;
+import org.apache.wayang.core.plan.wayangplan.WayangPlan;
+import org.apache.wayang.core.types.DataSetType;
+import org.apache.wayang.core.util.ReflectionUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+/**
+ * Query4:
+ *
+ * SELECT MIN(cn.name) AS company_name,
+ * MIN(lt.link) AS link_type,
+ * MIN(t.title) AS western_follow_up
+ * FROM company_name AS cn,
+ * company_type AS ct,
+ * keyword AS k,
+ * link_type AS lt,
+ * movie_companies AS mc,
+ * movie_info AS mi,
+ * movie_keyword AS mk,
+ * movie_link AS ml,
+ * title AS t
+ * WHERE cn.country_code !='[pl]'
+ * AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%')
+ * AND ct.kind ='production companies'
+ * AND k.keyword ='sequel'
+ * AND lt.link LIKE '%follow%'
+ * AND mc.note IS NULL
+ * AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark',
+ * 'Swedish', 'Denish', 'Norwegian', 'German')
+ * AND t.production_year BETWEEN 1950 AND 2000
+ * AND lt.id = ml.link_type_id
+ * AND ml.movie_id = t.id
+ * AND t.id = mk.movie_id
+ * AND mk.keyword_id = k.id
+ * AND t.id = mc.movie_id
+ * AND mc.company_type_id = ct.id
+ * AND mc.company_id = cn.id
+ * AND mi.movie_id = t.id
+ * AND ml.movie_id = mk.movie_id
+ * AND ml.movie_id = mc.movie_id
+ * AND mk.movie_id = mc.movie_id
+ * AND ml.movie_id = mi.movie_id
+ * AND mk.movie_id = mi.movie_id
+ * AND mc.movie_id = mi.movie_id
+ * AND cn.name_pcode_nf = cn.name_pcode_sf;
+ */
+public class Query4 {
+
+ public static WayangPlan getWayangPlan(String dataPath, Collection collector){
+
+ TextFileSource companyNameText = new TextFileSource(dataPath + "company_name.csv", "UTF-8");
+ TextFileSource companyTypeText = new TextFileSource(dataPath + "company_type.csv", "UTF-8");
+ TextFileSource keywordText = new TextFileSource(dataPath + "keyword.csv", "UTF-8");
+ TextFileSource linkTypeText = new TextFileSource(dataPath + "link_type.csv", "UTF-8");
+ TextFileSource movieCompaniesText = new TextFileSource(dataPath + "movie_companies.csv", "UTF-8");
+ TextFileSource movieInfoText = new TextFileSource(dataPath + "movie_info.csv", "UTF-8");
+ TextFileSource movieKeywordText = new TextFileSource(dataPath + "movie_keyword.csv", "UTF-8");
+ TextFileSource movieLinkText = new TextFileSource(dataPath + "movie_link.csv", "UTF-8");
+ TextFileSource titleText = new TextFileSource(dataPath + "title.csv", "UTF-8");
+
+ MapOperator cnParser =
+ new MapOperator<>(CompanyName::parseCsv, String.class, CompanyName.class);
+ MapOperator ctParser =
+ new MapOperator<>(CompanyType::parseCsv, String.class, CompanyType.class);
+ MapOperator kParser =
+ new MapOperator<>(Keyword::parseCsv, String.class, Keyword.class);
+ MapOperator ltParser =
+ new MapOperator<>(LinkType::parseCsv, String.class, LinkType.class);
+ MapOperator mcParser =
+ new MapOperator<>(MovieCompanies::parseCsv, String.class, MovieCompanies.class);
+ MapOperator miParser =
+ new MapOperator<>(MovieInfo::parseCsv, String.class, MovieInfo.class);
+ MapOperator mkParser =
+ new MapOperator<>(MovieKeyword::parseCsv, String.class, MovieKeyword.class);
+ MapOperator mlParser =
+ new MapOperator<>(MovieLink::parseCsv, String.class, MovieLink.class);
+ MapOperator tParser =
+ new MapOperator<>(Title::parseCsv, String.class, Title.class);
+
+ FilterOperator cnFilter =
+ new FilterOperator<>(cn -> cn.countryCode() != "[pl]", CompanyName.class);
+ FilterOperator cnFilterTwo =
+ new FilterOperator<>(cn -> cn.name().contains("Film") || cn.name().contains("Warner"), CompanyName.class);
+ FilterOperator cnFilterThree = new FilterOperator<>(cnFilterTwo);
+
+ FilterOperator ctFilter =
+ new FilterOperator<>(ct -> ct.kind().equals("production companies"), CompanyType.class);
+
+ FilterOperator kFilter =
+ new FilterOperator<>(k -> k.keyword().equals("sequel"), Keyword.class);
+
+ FilterOperator ltFilter =
+ new FilterOperator<>(lt -> lt.link().contains("follow"), LinkType.class);
+
+ FilterOperator mcFilter =
+ new FilterOperator<>(mc -> mc.note() == null, MovieCompanies.class);
+
+ FilterOperator miFilter =
+ new FilterOperator<>(mi ->
+ Arrays.asList("Sweden","Norway","Germany","Denmark",
+ "Swedish","Denish","Norwegian","German")
+ .contains(mi.info()),
+ MovieInfo.class);
+
+ FilterOperator tFilter =
+ new FilterOperator<>(t -> t.productionYear() >= 1950 && t.productionYear() <= 2000, Title.class);
+
+ JoinOperator ltMlJoin =
+ new JoinOperator<>(LinkType::id, MovieLink::linkTypeId,
+ LinkType.class, MovieLink.class, Integer.class);
+
+ JoinOperator, Title, Integer> ltMlTJoin =
+ new JoinOperator<>(ltMl -> ltMl.field1.movieId(),
+ Title::id,
+ ReflectionUtils.specify(Tuple2.class),
+ Title.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword, Integer> ltMlTMkJoin =
+ new JoinOperator<>(ltMlT -> ltMlT.field1.id(),
+ MovieKeyword::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieKeyword.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword, Integer> ltMlTMkKJoin =
+ new JoinOperator<>(ltMlTMk -> ltMlTMk.field1.keywordId(),
+ Keyword::id,
+ ReflectionUtils.specify(Tuple2.class),
+ Keyword.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies, Integer> ltMlTMkKMcJoin =
+ new JoinOperator<>(ltMlTMkK -> ltMlTMkK.field0.field1.id(),
+ MovieCompanies::movieId,
+ ReflectionUtils.specify(Tuple2.class),
+ MovieCompanies.class,
+ Integer.class);
+
+ JoinOperator, Title>, MovieKeyword>, Keyword>, MovieCompanies>, CompanyType, Integer> ltMlTMkKMcCtJoin =
+ new JoinOperator<>(ltMlTMkKMc -> ltMlTMkKMc.field1.companyTypeId(),
+ CompanyType::id,
+ ReflectionUtils.specify(Tuple2.class),
+ CompanyType.class,
+ Integer.class);
+
+ JoinOperator ltMlTMkKMcCtCnJoin =
+ new JoinOperator<>(ltMlTMkKMcCt -> ltMlTMkKMcCt.field0.field1.companyId(),
+ CompanyName::id,
+ ReflectionUtils.specify(LtMlTMkKMcCt.class),
+ CompanyName.class,
+ Integer.class);
+
+ JoinOperator