diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java index 3e06c95f7..f56c6080d 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java @@ -47,15 +47,16 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.function.ToDoubleFunction; import java.util.stream.Collectors; +import java.util.stream.StreamSupport; /** * This graph contains a set of {@link ChannelConversion}s. @@ -168,7 +169,7 @@ private Tree mergeTrees(Collection trees) { final Tree firstTree = iterator.next(); Bitmask combinationSettledIndices = new Bitmask(firstTree.settledDestinationIndices); int maxSettledIndices = combinationSettledIndices.cardinality(); - final HashSet employedChannelDescriptors = new HashSet<>(firstTree.employedChannelDescriptors); + final LinkedHashSet employedChannelDescriptors = new LinkedHashSet<>(firstTree.employedChannelDescriptors); int maxVisitedChannelDescriptors = employedChannelDescriptors.size(); double costs = firstTree.costs; TreeVertex newRoot = new TreeVertex(firstTree.root.channelDescriptor, firstTree.root.settledIndices); @@ -222,7 +223,11 @@ public static class CostbasedTreeSelectionStrategy implements TreeSelectionStrat @Override public Tree select(Tree t1, Tree t2) { - return t1.costs <= t2.costs ? t1 : t2; + int cmp = Double.compare(t1.costs, t2.costs); + if (cmp == 0) { + cmp = t1.getDeterministicSignature().compareTo(t2.getDeterministicSignature()); + } + return cmp <= 0 ? t1 : t2; } } @@ -381,7 +386,7 @@ private ShortestTreeSearcher(OutputSlot sourceOutput, this.existingDestinationChannelIndices = new Bitmask(); this.collectExistingChannels(sourceChannel); - this.openChannelDescriptors = new HashSet<>(openChannels.size()); + this.openChannelDescriptors = new LinkedHashSet<>(openChannels.size()); for (Channel openChannel : openChannels) { this.openChannelDescriptors.add(openChannel.getDescriptor()); } @@ -477,7 +482,9 @@ private Set resolveSupportedChannels(final InputSlot input final List supportedInputChannels = owner.getSupportedInputChannels(input.getIndex()); if (input.isLoopInvariant()) { // Loop input is needed in several iterations and must therefore be reusable. - return supportedInputChannels.stream().filter(ChannelDescriptor::isReusable).collect(Collectors.toSet()); + return supportedInputChannels.stream() + .filter(ChannelDescriptor::isReusable) + .collect(Collectors.toCollection(LinkedHashSet::new)); } else { return WayangCollections.asSet(supportedInputChannels); } @@ -546,7 +553,7 @@ private void kernelizeChannelRequests() { } if (channelDescriptors.size() - numReusableChannels == 1) { iterator.remove(); - channelDescriptors = new HashSet<>(channelDescriptors); + channelDescriptors = new LinkedHashSet<>(channelDescriptors); channelDescriptors.removeIf(channelDescriptor -> !channelDescriptor.isReusable()); kernelDestChannelDescriptorSetsToIndicesUpdates.add(new Tuple<>(channelDescriptors, indices)); } @@ -575,7 +582,7 @@ private void kernelizeChannelRequests() { */ private Tree searchTree() { // Prepare the recursive traversal. - final HashSet visitedChannelDescriptors = new HashSet<>(16); + final LinkedHashSet visitedChannelDescriptors = new LinkedHashSet<>(16); visitedChannelDescriptors.add(this.sourceChannelDescriptor); // Perform the traversal. @@ -777,7 +784,7 @@ private Set getSuccessorChannelDescriptors(ChannelDescriptor final Channel channel = this.existingChannels.get(descriptor); if (channel == null || this.openChannelDescriptors.contains(descriptor)) return null; - Set result = new HashSet<>(); + Set result = new LinkedHashSet<>(); for (ExecutionTask consumer : channel.getConsumers()) { if (!consumer.getOperator().isAuxiliary()) continue; for (Channel successorChannel : consumer.getOutputChannels()) { @@ -988,7 +995,12 @@ private static class Tree { * * @see TreeVertex#channelDescriptor */ - private final Set employedChannelDescriptors = new HashSet<>(); + private final Set employedChannelDescriptors = new LinkedHashSet<>(); + + /** + * Cached deterministic signature for tie-breaking. + */ + private String deterministicSignature; /** * The sum of the costs of all {@link TreeEdge}s of this instance. @@ -1010,6 +1022,7 @@ static Tree singleton(ChannelDescriptor channelDescriptor, Bitmask settledIndice this.root = root; this.settledDestinationIndices = settledDestinationIndices; this.employedChannelDescriptors.add(root.channelDescriptor); + this.deterministicSignature = null; } /** @@ -1033,6 +1046,21 @@ void reroot(ChannelDescriptor newRootChannelDescriptor, this.employedChannelDescriptors.add(newRootChannelDescriptor); this.settledDestinationIndices.orInPlace(newRootSettledIndices); this.costs += edge.costEstimate; + this.deterministicSignature = null; + } + + private String getDeterministicSignature() { + if (this.deterministicSignature == null) { + final String descriptorSignature = this.employedChannelDescriptors.stream() + .map(Object::toString) + .sorted() + .collect(Collectors.joining("|")); + final String indexSignature = StreamSupport.stream(this.settledDestinationIndices.spliterator(), false) + .map(String::valueOf) + .collect(Collectors.joining(",")); + this.deterministicSignature = descriptorSignature + "#" + indexSignature; + } + return this.deterministicSignature; } @Override @@ -1090,7 +1118,7 @@ private void copyEdgesFrom(TreeVertex that) { * @return a {@link Set} of said {@link ChannelConversion}s */ private Set getChildChannelConversions() { - Set channelConversions = new HashSet<>(); + Set channelConversions = new LinkedHashSet<>(); for (TreeEdge edge : this.outEdges) { channelConversions.add(edge.channelConversion); channelConversions.addAll(edge.destination.getChildChannelConversions()); diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java index ebc0f8cd2..2981d4b2a 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java @@ -66,11 +66,7 @@ public class DefaultEstimatableCost implements EstimatableCost { Set executedStages ) { final PlanImplementation bestPlanImplementation = executionPlans.stream() - .reduce((p1, p2) -> { - final double t1 = p1.getSquashedCostEstimate(); - final double t2 = p2.getSquashedCostEstimate(); - return t1 < t2 ? p1 : p2; - }) + .min(PlanImplementation.costComparator()) .orElseThrow(() -> new WayangException("Could not find an execution plan.")); return bestPlanImplementation; } diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java index a228f148f..fa797e900 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java @@ -85,7 +85,7 @@ private PlanImplementation selectBestPlanBinary(PlanImplementation p1, PlanImplementation p2) { final double t1 = p1.getSquashedCostEstimate(true); final double t2 = p2.getSquashedCostEstimate(true); - final boolean isPickP1 = t1 <= t2; + final boolean isPickP1 = PlanImplementation.costComparator().compare(p1, p2) <= 0; if (logger.isDebugEnabled()) { if (isPickP1) { LogManager.getLogger(LatentOperatorPruningStrategy.class).debug( diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java index e00753c37..bec527383 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java @@ -42,7 +42,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -91,7 +91,7 @@ public class PlanEnumeration { * Creates a new instance. */ public PlanEnumeration() { - this(new HashSet<>(), new HashSet<>(), new HashSet<>()); + this(new LinkedHashSet<>(), new LinkedHashSet<>(), new LinkedHashSet<>()); } /** @@ -412,16 +412,15 @@ private Collection concatenatePartialPlansBatchwise( if (junction == null) continue; // If we found a junction, then we can enumerate all PlanImplementation combinations. - final List> groupPlans = WayangCollections.map( - concatGroupCombo, - concatGroup -> { - Set concatDescriptors = concatGroup2concatDescriptor.get(concatGroup); - Set planImplementations = new HashSet<>(concatDescriptors.size()); - for (PlanImplementation.ConcatenationDescriptor concatDescriptor : concatDescriptors) { - planImplementations.add(concatDescriptor.getPlanImplementation()); - } - return planImplementations; - }); + final List> groupPlans = WayangCollections.map( + concatGroupCombo, + concatGroup -> { + Set concatDescriptors = concatGroup2concatDescriptor.get(concatGroup); + return concatDescriptors.stream() + .map(PlanImplementation.ConcatenationDescriptor::getPlanImplementation) + .sorted(PlanImplementation.structuralComparator()) + .collect(Collectors.toList()); + }); for (List planCombo : WayangCollections.streamedCrossProduct(groupPlans)) { PlanImplementation basePlan = planCombo.get(0); diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java index b66056af9..59103bab0 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java @@ -47,9 +47,11 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -57,6 +59,7 @@ import java.util.Set; import java.util.function.ToDoubleFunction; import java.util.stream.Collectors; +import java.util.stream.IntStream; import java.util.stream.Stream; /** @@ -65,6 +68,11 @@ public class PlanImplementation { private static final Logger logger = LogManager.getLogger(PlanImplementation.class); + private static final Comparator COST_COMPARATOR = + Comparator.comparingDouble((PlanImplementation plan) -> plan.getSquashedCostEstimate(true)) + .thenComparing(PlanImplementation::getDeterministicIdentifier); + private static final Comparator STRUCTURAL_COMPARATOR = + Comparator.comparing(PlanImplementation::getDeterministicIdentifier); /** * {@link ExecutionOperator}s contained in this instance. @@ -180,6 +188,14 @@ private PlanImplementation(PlanEnumeration planEnumeration, assert this.planEnumeration != null; } + public static Comparator costComparator() { + return COST_COMPARATOR; + } + + public static Comparator structuralComparator() { + return STRUCTURAL_COMPARATOR; + } + /** * @return the {@link PlanEnumeration} this instance belongs to @@ -238,7 +254,7 @@ Collection> findExecutionOperatorInputs(final InputSlot someInpu // Discern LoopHeadOperator InputSlots and loop body InputSlots. final List iterationImpls = loopImplementation.getIterationImplementations(); - final Collection> collector = new HashSet<>(innerInputs.size()); + final Collection> collector = new LinkedHashSet<>(innerInputs.size()); for (InputSlot innerInput : innerInputs) { if (innerInput.getOwner() == loopSubplan.getLoopHead()) { final LoopImplementation.IterationImplementation initialIterationImpl = iterationImpls.get(0); @@ -312,7 +328,7 @@ Collection, PlanImplementation>> findExecutionOperatorOutput // For all the iterations, return the potential OutputSlots. final List iterationImpls = loopImplementation.getIterationImplementations(); - final Set, PlanImplementation>> collector = new HashSet<>(iterationImpls.size()); + final Set, PlanImplementation>> collector = new LinkedHashSet<>(iterationImpls.size()); for (LoopImplementation.IterationImplementation iterationImpl : iterationImpls) { final Collection, PlanImplementation>> outputsWithContext = iterationImpl.getBodyImplementation().findExecutionOperatorOutputWithContext(innerOutput); @@ -678,8 +694,8 @@ public double getSquashedCostEstimate() { private Tuple, List> getParallelOperatorJunctionAllCostEstimate(Operator operator) { - Set inputOperators = new HashSet<>(); - Set inputJunction = new HashSet<>(); + Set inputOperators = new LinkedHashSet<>(); + Set inputJunction = new LinkedHashSet<>(); List probalisticCost = new ArrayList<>(); List squashedCost = new ArrayList<>(); @@ -976,6 +992,67 @@ Stream streamOperators() { return operatorStream; } + /** + * Provides a deterministic identifier that captures the current state of this plan. While not guaranteed to + * be unique, it is stable across runs for the same logical plan and can therefore be used for reproducible + * ordering. + * + * @return the deterministic identifier + */ + public String getDeterministicIdentifier() { + final String operatorDescriptor = this.operators.stream() + .map(PlanImplementation::describeOperator) + .sorted() + .collect(Collectors.joining("|")); + final String junctionDescriptor = this.junctions.values().stream() + .map(PlanImplementation::describeJunction) + .sorted() + .collect(Collectors.joining("|")); + final String loopDescriptor = this.loopImplementations.entrySet().stream() + .map(entry -> describeLoop(entry.getKey(), entry.getValue())) + .sorted() + .collect(Collectors.joining("|")); + return operatorDescriptor + "#" + junctionDescriptor + "#" + loopDescriptor; + } + + private static String describeOperator(Operator operator) { + final String name = operator.getName() == null ? "" : operator.getName(); + return operator.getClass().getName() + ":" + name + ":" + operator.getEpoch(); + } + + private static String describeJunction(Junction junction) { + final String source = describeOutputSlot(junction.getSourceOutput()); + final String targets = IntStream.range(0, junction.getNumTargets()) + .mapToObj(i -> describeInputSlot(junction.getTargetInput(i))) + .sorted() + .collect(Collectors.joining(",")); + return source + "->" + targets; + } + + private static String describeLoop(LoopSubplan loop, LoopImplementation implementation) { + final String descriptor = describeOperator(loop); + final String iterationDescriptor = implementation.getIterationImplementations().stream() + .map(iteration -> Integer.toString(iteration.getNumIterations())) + .collect(Collectors.joining(",")); + return descriptor + ":" + iterationDescriptor; + } + + private static String describeInputSlot(InputSlot slot) { + if (slot == null) { + return "null"; + } + final Operator owner = slot.getOwner(); + return describeOperator(owner) + ".in[" + slot.getIndex() + "]:" + slot.getName(); + } + + private static String describeOutputSlot(OutputSlot slot) { + if (slot == null) { + return "null"; + } + final Operator owner = slot.getOwner(); + return describeOperator(owner) + ".out[" + slot.getIndex() + "]:" + slot.getName(); + } + @Override public String toString() { return String.format("PlanImplementation[%s, %s, costs=%s]", diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java index a80f9d99c..2519bc13c 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java @@ -40,16 +40,8 @@ public void prune(PlanEnumeration planEnumeration) { if (planEnumeration.getPlanImplementations().size() <= this.k) return; ArrayList planImplementations = new ArrayList<>(planEnumeration.getPlanImplementations()); - planImplementations.sort(this::comparePlanImplementations); + planImplementations.sort(PlanImplementation.costComparator()); planEnumeration.getPlanImplementations().retainAll(planImplementations.subList(0, this.k)); } - - private int comparePlanImplementations(PlanImplementation p1, - PlanImplementation p2) { - final double t1 = p1.getSquashedCostEstimate(true); - final double t2 = p2.getSquashedCostEstimate(true); - return Double.compare(t1, t2); - } - } diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java index 574d4ff44..4fa6bf3a3 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java @@ -18,14 +18,14 @@ package org.apache.wayang.core.util; -import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.Set; /** * Maps keys to multiple values. Each key value pair is unique. */ -public class MultiMap extends HashMap> { +public class MultiMap extends LinkedHashMap> { /** * Associate a key with a new value. @@ -35,7 +35,7 @@ public class MultiMap extends HashMap> { * @return whether the value was not yet associated with the key */ public boolean putSingle(K key, V value) { - final Set values = this.computeIfAbsent(key, k -> new HashSet<>()); + final Set values = this.computeIfAbsent(key, k -> new LinkedHashSet<>()); return values.add(value); } diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java index 5b3e8918f..b1eb6d787 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/WayangCollections.java @@ -24,7 +24,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; @@ -59,7 +59,7 @@ public static Set asSet(Collection collection) { if (collection instanceof Set) { return (Set) collection; } - return new HashSet<>(collection); + return new LinkedHashSet<>(collection); } /** @@ -69,7 +69,7 @@ public static Set asSet(Iterable iterable) { if (iterable instanceof Set) { return (Set) iterable; } - Set set = new HashSet<>(); + Set set = new LinkedHashSet<>(); for (T t : iterable) { set.add(t); } @@ -80,7 +80,7 @@ public static Set asSet(Iterable iterable) { * Provides the given {@code values} as {@link Set}. */ public static Set asSet(T... values) { - Set set = new HashSet<>(values.length); + Set set = new LinkedHashSet<>(values.length); for (T value : values) { set.add(value); } diff --git a/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java new file mode 100644 index 000000000..7bdab816f --- /dev/null +++ b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraphDeterminismTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.core.optimizer.channels; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.DefaultOptimizationContext; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.OptimizationUtils; +import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimate; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.executionplan.ExecutionTask; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.InputSlot; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.Junction; +import org.apache.wayang.core.test.DummyExecutionOperator; +import org.apache.wayang.core.test.DummyExternalReusableChannel; +import org.apache.wayang.core.test.DummyNonReusableChannel; +import org.apache.wayang.core.test.DummyReusableChannel; +import org.apache.wayang.core.test.MockFactory; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class ChannelConversionGraphDeterminismTest { + + private static Supplier createDummyExecutionOperatorFactory(ChannelDescriptor channelDescriptor) { + return () -> { + ExecutionOperator execOp = new DummyExecutionOperator(1, 1, false); + execOp.getSupportedOutputChannels(0).add(channelDescriptor); + return execOp; + }; + } + + private static DefaultChannelConversion conversion(ChannelDescriptor source, ChannelDescriptor target) { + return new DefaultChannelConversion(source, target, createDummyExecutionOperatorFactory(target)); + } + + @Test + void channelConversionSelectionIsStable() { + List first = computeJunctionFingerprint(); + List second = computeJunctionFingerprint(); + assertEquals(first, second, "Channel conversion choices must be deterministic."); + } + + private static List computeJunctionFingerprint() { + Configuration configuration = new Configuration(); + ChannelConversionGraph graph = new ChannelConversionGraph(configuration); + graph.add(conversion(DummyReusableChannel.DESCRIPTOR, DummyNonReusableChannel.DESCRIPTOR)); + graph.add(conversion(DummyReusableChannel.DESCRIPTOR, DummyExternalReusableChannel.DESCRIPTOR)); + graph.add(conversion(DummyExternalReusableChannel.DESCRIPTOR, DummyNonReusableChannel.DESCRIPTOR)); + graph.add(conversion(DummyNonReusableChannel.DESCRIPTOR, DummyReusableChannel.DESCRIPTOR)); + + Job job = MockFactory.createJob(configuration); + OptimizationContext optimizationContext = new DefaultOptimizationContext(job); + + DummyExecutionOperator sourceOperator = new DummyExecutionOperator(0, 1, false); + sourceOperator.getSupportedOutputChannels(0).add(DummyReusableChannel.DESCRIPTOR); + optimizationContext.addOneTimeOperator(sourceOperator) + .setOutputCardinality(0, new CardinalityEstimate(1000, 1000, 1d)); + + DummyExecutionOperator destOperator0 = new DummyExecutionOperator(1, 1, false); + destOperator0.getSupportedInputChannels(0).add(DummyNonReusableChannel.DESCRIPTOR); + + DummyExecutionOperator destOperator1 = new DummyExecutionOperator(1, 1, false); + destOperator1.getSupportedInputChannels(0).add(DummyExternalReusableChannel.DESCRIPTOR); + + Junction junction = graph.findMinimumCostJunction( + sourceOperator.getOutput(0), + Arrays.asList(destOperator0.getInput(0), destOperator1.getInput(0)), + optimizationContext, + false + ); + + return describeJunction(junction); + } + + private static List describeJunction(Junction junction) { + List descriptorList = new ArrayList<>(); + descriptorList.add(describeChannel(junction.getSourceChannel(), true)); + for (int i = 0; i < junction.getNumTargets(); i++) { + descriptorList.add(describeChannel(junction.getTargetChannel(i), false)); + } + return descriptorList; + } + + private static String describeChannel(Channel channel, boolean isSourceChannel) { + if (channel == null) { + return "null"; + } + List descriptors = new ArrayList<>(); + Channel cursor = channel; + while (cursor != null) { + descriptors.add(cursor.getDescriptor().toString() + (cursor.isCopy() ? ":copy" : ":orig")); + ExecutionTask producer = cursor.getProducer(); + if (producer == null || producer.getNumInputChannels() == 0) { + break; + } + // If we are describing the top-level source channel (junction entry), stop once we reach the producer that + // has no inputs. For target channels, follow until the conversion tree ends. + if (isSourceChannel) { + cursor = producer.getNumInputChannels() == 0 ? null : producer.getInputChannel(0); + } else if (producer.getNumInputChannels() == 0) { + cursor = null; + } else { + cursor = producer.getInputChannel(0); + } + } + Collections.reverse(descriptors); + return descriptors.stream().collect(Collectors.joining("->")); + } +} diff --git a/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java new file mode 100644 index 000000000..f72525740 --- /dev/null +++ b/wayang-commons/wayang-core/src/test/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumerationDeterminismTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.core.optimizer.enumeration; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.api.configuration.ExplicitCollectionProvider; +import org.apache.wayang.core.optimizer.DefaultOptimizationContext; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.costs.ConstantLoadProfileEstimator; +import org.apache.wayang.core.optimizer.costs.LoadEstimate; +import org.apache.wayang.core.optimizer.costs.LoadProfile; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.InputSlot; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.test.DummyExecutionOperator; +import org.apache.wayang.core.test.DummyReusableChannel; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Integration test that exercises {@link PlanEnumeration#concatenate(OutputSlot, Collection, Map, OptimizationContext, org.apache.wayang.commons.util.profiledb.model.measurement.TimeMeasurement)} + * to ensure that plan combinations are produced deterministically. + */ +class PlanEnumerationDeterminismTest { + + @Test + void concatenationProducesStablePlanOrdering() { + Configuration configuration = new Configuration(); + configuration.setPruningStrategyClassProvider( + new ExplicitCollectionProvider>(configuration) + ); + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + DummyExecutionOperator producer = new DummyExecutionOperator(0, 1, false); + DummyExecutionOperator consumer = new DummyExecutionOperator(1, 0, false); + registerChannelDescriptors(producer, Collections.singletonList(consumer)); + registerLoadEstimator(configuration, producer, 10); + registerLoadEstimator(configuration, consumer, 5); + + List firstRun = enumerateDeterministicIds(job, producer, Collections.singletonList(consumer), 3, 2); + List secondRun = enumerateDeterministicIds(job, producer, Collections.singletonList(consumer), 3, 2); + + assertTrue(firstRun.size() > 1, "Expected multiple plan implementations."); + assertEquals(firstRun, secondRun, "Enumeration order must be deterministic."); + } + + @Test + void concatenationWithMultipleTargetsRemainsStable() { + Configuration configuration = new Configuration(); + configuration.setPruningStrategyClassProvider( + new ExplicitCollectionProvider>(configuration) + ); + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + DummyExecutionOperator producer = new DummyExecutionOperator(0, 1, false); + DummyExecutionOperator consumerA = new DummyExecutionOperator(1, 0, false); + DummyExecutionOperator consumerB = new DummyExecutionOperator(1, 0, false); + registerChannelDescriptors(producer, Arrays.asList(consumerA, consumerB)); + registerLoadEstimator(configuration, producer, 10); + registerLoadEstimator(configuration, consumerA, 7); + registerLoadEstimator(configuration, consumerB, 3); + + List firstRun = enumerateDeterministicIds(job, producer, Arrays.asList(consumerA, consumerB), 4, 2); + List secondRun = enumerateDeterministicIds(job, producer, Arrays.asList(consumerA, consumerB), 4, 2); + + assertEquals(firstRun, secondRun, "Enumeration order with multiple targets must be deterministic."); + } + + private static List enumerateDeterministicIds(Job job, + ExecutionOperator producer, + List consumers, + int numBaseCopies, + int numTargetCopies) { + DefaultOptimizationContext optimizationContext = new DefaultOptimizationContext(job); + optimizationContext.addOneTimeOperator(producer); + consumers.forEach(optimizationContext::addOneTimeOperator); + + PlanEnumeration baseEnumeration = PlanEnumeration.createSingleton(producer, optimizationContext); + duplicatePlanImplementations(baseEnumeration, numBaseCopies); + + Map, PlanEnumeration> targets = new LinkedHashMap<>(); + consumers.forEach(consumer -> { + PlanEnumeration targetEnumeration = PlanEnumeration.createSingleton((ExecutionOperator) consumer, optimizationContext); + duplicatePlanImplementations(targetEnumeration, numTargetCopies); + targets.put(consumer.getInput(0), targetEnumeration); + }); + + PlanEnumeration concatenated = baseEnumeration.concatenate( + producer.getOutput(0), + Collections.emptyList(), + targets, + optimizationContext, + null + ); + + return concatenated.getPlanImplementations().stream() + .map(PlanImplementation::getDeterministicIdentifier) + .collect(Collectors.toList()); + } + + private static void duplicatePlanImplementations(PlanEnumeration enumeration, int desiredCount) { + PlanImplementation template = enumeration.getPlanImplementations().iterator().next(); + while (enumeration.getPlanImplementations().size() < desiredCount) { + enumeration.add(new PlanImplementation(template)); + } + } + + private static void registerLoadEstimator(Configuration configuration, + ExecutionOperator operator, + long cpuCost) { + ConstantLoadProfileEstimator estimator = new ConstantLoadProfileEstimator( + new LoadProfile(new LoadEstimate(cpuCost), new LoadEstimate(1)) + ); + configuration.getOperatorLoadProfileEstimatorProvider().set(operator, estimator); + } + + private static void registerChannelDescriptors(DummyExecutionOperator producer, + List consumers) { + producer.getSupportedOutputChannels(0).add(DummyReusableChannel.DESCRIPTOR); + consumers.forEach(consumer -> consumer.getSupportedInputChannels(0).add(DummyReusableChannel.DESCRIPTOR)); + } +}