Skip to content

Commit 198ff8e

Browse files
authored
Merge pull request #662 from novatechflow/pr/deterministic-plans
Ensure deterministic plan enumeration and channel conversion
2 parents: e5527cd + e7d850e; commit: 198ff8e

10 files changed

Lines changed: 431 additions & 49 deletions

File tree

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/channels/ChannelConversionGraph.java

Lines changed: 38 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -47,15 +47,16 @@
4747
import java.util.Collections;
4848
import java.util.Comparator;
4949
import java.util.HashMap;
50-
import java.util.HashSet;
5150
import java.util.Iterator;
5251
import java.util.LinkedList;
52+
import java.util.LinkedHashSet;
5353
import java.util.List;
5454
import java.util.Map;
5555
import java.util.Random;
5656
import java.util.Set;
5757
import java.util.function.ToDoubleFunction;
5858
import java.util.stream.Collectors;
59+
import java.util.stream.StreamSupport;
5960

6061
/**
6162
* This graph contains a set of {@link ChannelConversion}s.
@@ -168,7 +169,7 @@ private Tree mergeTrees(Collection<Tree> trees) {
168169
final Tree firstTree = iterator.next();
169170
Bitmask combinationSettledIndices = new Bitmask(firstTree.settledDestinationIndices);
170171
int maxSettledIndices = combinationSettledIndices.cardinality();
171-
final HashSet<ChannelDescriptor> employedChannelDescriptors = new HashSet<>(firstTree.employedChannelDescriptors);
172+
final LinkedHashSet<ChannelDescriptor> employedChannelDescriptors = new LinkedHashSet<>(firstTree.employedChannelDescriptors);
172173
int maxVisitedChannelDescriptors = employedChannelDescriptors.size();
173174
double costs = firstTree.costs;
174175
TreeVertex newRoot = new TreeVertex(firstTree.root.channelDescriptor, firstTree.root.settledIndices);
@@ -222,7 +223,11 @@ public static class CostbasedTreeSelectionStrategy implements TreeSelectionStrat
222223

223224
@Override
224225
public Tree select(Tree t1, Tree t2) {
225-
return t1.costs <= t2.costs ? t1 : t2;
226+
int cmp = Double.compare(t1.costs, t2.costs);
227+
if (cmp == 0) {
228+
cmp = t1.getDeterministicSignature().compareTo(t2.getDeterministicSignature());
229+
}
230+
return cmp <= 0 ? t1 : t2;
226231
}
227232

228233
}
@@ -381,7 +386,7 @@ private ShortestTreeSearcher(OutputSlot<?> sourceOutput,
381386
this.existingDestinationChannelIndices = new Bitmask();
382387

383388
this.collectExistingChannels(sourceChannel);
384-
this.openChannelDescriptors = new HashSet<>(openChannels.size());
389+
this.openChannelDescriptors = new LinkedHashSet<>(openChannels.size());
385390
for (Channel openChannel : openChannels) {
386391
this.openChannelDescriptors.add(openChannel.getDescriptor());
387392
}
@@ -477,7 +482,9 @@ private Set<ChannelDescriptor> resolveSupportedChannels(final InputSlot<?> input
477482
final List<ChannelDescriptor> supportedInputChannels = owner.getSupportedInputChannels(input.getIndex());
478483
if (input.isLoopInvariant()) {
479484
// Loop input is needed in several iterations and must therefore be reusable.
480-
return supportedInputChannels.stream().filter(ChannelDescriptor::isReusable).collect(Collectors.toSet());
485+
return supportedInputChannels.stream()
486+
.filter(ChannelDescriptor::isReusable)
487+
.collect(Collectors.toCollection(LinkedHashSet::new));
481488
} else {
482489
return WayangCollections.asSet(supportedInputChannels);
483490
}
@@ -546,7 +553,7 @@ private void kernelizeChannelRequests() {
546553
}
547554
if (channelDescriptors.size() - numReusableChannels == 1) {
548555
iterator.remove();
549-
channelDescriptors = new HashSet<>(channelDescriptors);
556+
channelDescriptors = new LinkedHashSet<>(channelDescriptors);
550557
channelDescriptors.removeIf(channelDescriptor -> !channelDescriptor.isReusable());
551558
kernelDestChannelDescriptorSetsToIndicesUpdates.add(new Tuple<>(channelDescriptors, indices));
552559
}
@@ -575,7 +582,7 @@ private void kernelizeChannelRequests() {
575582
*/
576583
private Tree searchTree() {
577584
// Prepare the recursive traversal.
578-
final HashSet<ChannelDescriptor> visitedChannelDescriptors = new HashSet<>(16);
585+
final LinkedHashSet<ChannelDescriptor> visitedChannelDescriptors = new LinkedHashSet<>(16);
579586
visitedChannelDescriptors.add(this.sourceChannelDescriptor);
580587

581588
// Perform the traversal.
@@ -777,7 +784,7 @@ private Set<ChannelDescriptor> getSuccessorChannelDescriptors(ChannelDescriptor
777784
final Channel channel = this.existingChannels.get(descriptor);
778785
if (channel == null || this.openChannelDescriptors.contains(descriptor)) return null;
779786

780-
Set<ChannelDescriptor> result = new HashSet<>();
787+
Set<ChannelDescriptor> result = new LinkedHashSet<>();
781788
for (ExecutionTask consumer : channel.getConsumers()) {
782789
if (!consumer.getOperator().isAuxiliary()) continue;
783790
for (Channel successorChannel : consumer.getOutputChannels()) {
@@ -988,7 +995,12 @@ private static class Tree {
988995
*
989996
* @see TreeVertex#channelDescriptor
990997
*/
991-
private final Set<ChannelDescriptor> employedChannelDescriptors = new HashSet<>();
998+
private final Set<ChannelDescriptor> employedChannelDescriptors = new LinkedHashSet<>();
999+
1000+
/**
1001+
* Cached deterministic signature for tie-breaking.
1002+
*/
1003+
private String deterministicSignature;
9921004

9931005
/**
9941006
* The sum of the costs of all {@link TreeEdge}s of this instance.
@@ -1010,6 +1022,7 @@ static Tree singleton(ChannelDescriptor channelDescriptor, Bitmask settledIndice
10101022
this.root = root;
10111023
this.settledDestinationIndices = settledDestinationIndices;
10121024
this.employedChannelDescriptors.add(root.channelDescriptor);
1025+
this.deterministicSignature = null;
10131026
}
10141027

10151028
/**
@@ -1033,6 +1046,21 @@ void reroot(ChannelDescriptor newRootChannelDescriptor,
10331046
this.employedChannelDescriptors.add(newRootChannelDescriptor);
10341047
this.settledDestinationIndices.orInPlace(newRootSettledIndices);
10351048
this.costs += edge.costEstimate;
1049+
this.deterministicSignature = null;
1050+
}
1051+
1052+
private String getDeterministicSignature() {
1053+
if (this.deterministicSignature == null) {
1054+
final String descriptorSignature = this.employedChannelDescriptors.stream()
1055+
.map(Object::toString)
1056+
.sorted()
1057+
.collect(Collectors.joining("|"));
1058+
final String indexSignature = StreamSupport.stream(this.settledDestinationIndices.spliterator(), false)
1059+
.map(String::valueOf)
1060+
.collect(Collectors.joining(","));
1061+
this.deterministicSignature = descriptorSignature + "#" + indexSignature;
1062+
}
1063+
return this.deterministicSignature;
10361064
}
10371065

10381066
@Override
@@ -1090,7 +1118,7 @@ private void copyEdgesFrom(TreeVertex that) {
10901118
* @return a {@link Set} of said {@link ChannelConversion}s
10911119
*/
10921120
private Set<ChannelConversion> getChildChannelConversions() {
1093-
Set<ChannelConversion> channelConversions = new HashSet<>();
1121+
Set<ChannelConversion> channelConversions = new LinkedHashSet<>();
10941122
for (TreeEdge edge : this.outEdges) {
10951123
channelConversions.add(edge.channelConversion);
10961124
channelConversions.addAll(edge.destination.getChildChannelConversions());

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/costs/DefaultEstimatableCost.java

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,11 +66,7 @@ public class DefaultEstimatableCost implements EstimatableCost {
6666
Set<ExecutionStage> executedStages
6767
) {
6868
final PlanImplementation bestPlanImplementation = executionPlans.stream()
69-
.reduce((p1, p2) -> {
70-
final double t1 = p1.getSquashedCostEstimate();
71-
final double t2 = p2.getSquashedCostEstimate();
72-
return t1 < t2 ? p1 : p2;
73-
})
69+
.min(PlanImplementation.costComparator())
7470
.orElseThrow(() -> new WayangException("Could not find an execution plan."));
7571
return bestPlanImplementation;
7672
}

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/LatentOperatorPruningStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ private PlanImplementation selectBestPlanBinary(PlanImplementation p1,
8585
PlanImplementation p2) {
8686
final double t1 = p1.getSquashedCostEstimate(true);
8787
final double t2 = p2.getSquashedCostEstimate(true);
88-
final boolean isPickP1 = t1 <= t2;
88+
final boolean isPickP1 = PlanImplementation.costComparator().compare(p1, p2) <= 0;
8989
if (logger.isDebugEnabled()) {
9090
if (isPickP1) {
9191
LogManager.getLogger(LatentOperatorPruningStrategy.class).debug(

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanEnumeration.java

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
import java.util.Collection;
4343
import java.util.Collections;
4444
import java.util.HashMap;
45-
import java.util.HashSet;
45+
import java.util.LinkedHashSet;
4646
import java.util.LinkedList;
4747
import java.util.List;
4848
import java.util.Map;
@@ -91,7 +91,7 @@ public class PlanEnumeration {
9191
* Creates a new instance.
9292
*/
9393
public PlanEnumeration() {
94-
this(new HashSet<>(), new HashSet<>(), new HashSet<>());
94+
this(new LinkedHashSet<>(), new LinkedHashSet<>(), new LinkedHashSet<>());
9595
}
9696

9797
/**
@@ -412,16 +412,15 @@ private Collection<PlanImplementation> concatenatePartialPlansBatchwise(
412412
if (junction == null) continue;
413413

414414
// If we found a junction, then we can enumerate all PlanImplementation combinations.
415-
final List<Set<PlanImplementation>> groupPlans = WayangCollections.map(
416-
concatGroupCombo,
417-
concatGroup -> {
418-
Set<PlanImplementation.ConcatenationDescriptor> concatDescriptors = concatGroup2concatDescriptor.get(concatGroup);
419-
Set<PlanImplementation> planImplementations = new HashSet<>(concatDescriptors.size());
420-
for (PlanImplementation.ConcatenationDescriptor concatDescriptor : concatDescriptors) {
421-
planImplementations.add(concatDescriptor.getPlanImplementation());
422-
}
423-
return planImplementations;
424-
});
415+
final List<List<PlanImplementation>> groupPlans = WayangCollections.map(
416+
concatGroupCombo,
417+
concatGroup -> {
418+
Set<PlanImplementation.ConcatenationDescriptor> concatDescriptors = concatGroup2concatDescriptor.get(concatGroup);
419+
return concatDescriptors.stream()
420+
.map(PlanImplementation.ConcatenationDescriptor::getPlanImplementation)
421+
.sorted(PlanImplementation.structuralComparator())
422+
.collect(Collectors.toList());
423+
});
425424

426425
for (List<PlanImplementation> planCombo : WayangCollections.streamedCrossProduct(groupPlans)) {
427426
PlanImplementation basePlan = planCombo.get(0);

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/PlanImplementation.java

Lines changed: 81 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -47,16 +47,19 @@
4747
import java.util.Arrays;
4848
import java.util.Collection;
4949
import java.util.Collections;
50+
import java.util.Comparator;
5051
import java.util.HashMap;
5152
import java.util.HashSet;
5253
import java.util.Iterator;
54+
import java.util.LinkedHashSet;
5355
import java.util.LinkedList;
5456
import java.util.List;
5557
import java.util.Map;
5658
import java.util.Objects;
5759
import java.util.Set;
5860
import java.util.function.ToDoubleFunction;
5961
import java.util.stream.Collectors;
62+
import java.util.stream.IntStream;
6063
import java.util.stream.Stream;
6164

6265
/**
@@ -65,6 +68,11 @@
6568
public class PlanImplementation {
6669

6770
private static final Logger logger = LogManager.getLogger(PlanImplementation.class);
71+
private static final Comparator<PlanImplementation> COST_COMPARATOR =
72+
Comparator.comparingDouble((PlanImplementation plan) -> plan.getSquashedCostEstimate(true))
73+
.thenComparing(PlanImplementation::getDeterministicIdentifier);
74+
private static final Comparator<PlanImplementation> STRUCTURAL_COMPARATOR =
75+
Comparator.comparing(PlanImplementation::getDeterministicIdentifier);
6876

6977
/**
7078
* {@link ExecutionOperator}s contained in this instance.
@@ -180,6 +188,14 @@ private PlanImplementation(PlanEnumeration planEnumeration,
180188
assert this.planEnumeration != null;
181189
}
182190

191+
public static Comparator<PlanImplementation> costComparator() {
192+
return COST_COMPARATOR;
193+
}
194+
195+
public static Comparator<PlanImplementation> structuralComparator() {
196+
return STRUCTURAL_COMPARATOR;
197+
}
198+
183199

184200
/**
185201
* @return the {@link PlanEnumeration} this instance belongs to
@@ -238,7 +254,7 @@ Collection<InputSlot<?>> findExecutionOperatorInputs(final InputSlot<?> someInpu
238254

239255
// Discern LoopHeadOperator InputSlots and loop body InputSlots.
240256
final List<LoopImplementation.IterationImplementation> iterationImpls = loopImplementation.getIterationImplementations();
241-
final Collection<InputSlot<?>> collector = new HashSet<>(innerInputs.size());
257+
final Collection<InputSlot<?>> collector = new LinkedHashSet<>(innerInputs.size());
242258
for (InputSlot<?> innerInput : innerInputs) {
243259
if (innerInput.getOwner() == loopSubplan.getLoopHead()) {
244260
final LoopImplementation.IterationImplementation initialIterationImpl = iterationImpls.get(0);
@@ -312,7 +328,7 @@ Collection<Tuple<OutputSlot<?>, PlanImplementation>> findExecutionOperatorOutput
312328
// For all the iterations, return the potential OutputSlots.
313329
final List<LoopImplementation.IterationImplementation> iterationImpls =
314330
loopImplementation.getIterationImplementations();
315-
final Set<Tuple<OutputSlot<?>, PlanImplementation>> collector = new HashSet<>(iterationImpls.size());
331+
final Set<Tuple<OutputSlot<?>, PlanImplementation>> collector = new LinkedHashSet<>(iterationImpls.size());
316332
for (LoopImplementation.IterationImplementation iterationImpl : iterationImpls) {
317333
final Collection<Tuple<OutputSlot<?>, PlanImplementation>> outputsWithContext =
318334
iterationImpl.getBodyImplementation().findExecutionOperatorOutputWithContext(innerOutput);
@@ -678,8 +694,8 @@ public double getSquashedCostEstimate() {
678694

679695
private Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(Operator operator) {
680696

681-
Set<Operator> inputOperators = new HashSet<>();
682-
Set<Junction> inputJunction = new HashSet<>();
697+
Set<Operator> inputOperators = new LinkedHashSet<>();
698+
Set<Junction> inputJunction = new LinkedHashSet<>();
683699

684700
List<ProbabilisticDoubleInterval> probalisticCost = new ArrayList<>();
685701
List<Double> squashedCost = new ArrayList<>();
@@ -976,6 +992,67 @@ Stream<ExecutionOperator> streamOperators() {
976992
return operatorStream;
977993
}
978994

995+
/**
996+
* Provides a deterministic identifier that captures the current state of this plan. While not guaranteed to
997+
* be unique, it is stable across runs for the same logical plan and can therefore be used for reproducible
998+
* ordering.
999+
*
1000+
* @return the deterministic identifier
1001+
*/
1002+
public String getDeterministicIdentifier() {
1003+
final String operatorDescriptor = this.operators.stream()
1004+
.map(PlanImplementation::describeOperator)
1005+
.sorted()
1006+
.collect(Collectors.joining("|"));
1007+
final String junctionDescriptor = this.junctions.values().stream()
1008+
.map(PlanImplementation::describeJunction)
1009+
.sorted()
1010+
.collect(Collectors.joining("|"));
1011+
final String loopDescriptor = this.loopImplementations.entrySet().stream()
1012+
.map(entry -> describeLoop(entry.getKey(), entry.getValue()))
1013+
.sorted()
1014+
.collect(Collectors.joining("|"));
1015+
return operatorDescriptor + "#" + junctionDescriptor + "#" + loopDescriptor;
1016+
}
1017+
1018+
private static String describeOperator(Operator operator) {
1019+
final String name = operator.getName() == null ? "" : operator.getName();
1020+
return operator.getClass().getName() + ":" + name + ":" + operator.getEpoch();
1021+
}
1022+
1023+
private static String describeJunction(Junction junction) {
1024+
final String source = describeOutputSlot(junction.getSourceOutput());
1025+
final String targets = IntStream.range(0, junction.getNumTargets())
1026+
.mapToObj(i -> describeInputSlot(junction.getTargetInput(i)))
1027+
.sorted()
1028+
.collect(Collectors.joining(","));
1029+
return source + "->" + targets;
1030+
}
1031+
1032+
private static String describeLoop(LoopSubplan loop, LoopImplementation implementation) {
1033+
final String descriptor = describeOperator(loop);
1034+
final String iterationDescriptor = implementation.getIterationImplementations().stream()
1035+
.map(iteration -> Integer.toString(iteration.getNumIterations()))
1036+
.collect(Collectors.joining(","));
1037+
return descriptor + ":" + iterationDescriptor;
1038+
}
1039+
1040+
private static String describeInputSlot(InputSlot<?> slot) {
1041+
if (slot == null) {
1042+
return "null";
1043+
}
1044+
final Operator owner = slot.getOwner();
1045+
return describeOperator(owner) + ".in[" + slot.getIndex() + "]:" + slot.getName();
1046+
}
1047+
1048+
private static String describeOutputSlot(OutputSlot<?> slot) {
1049+
if (slot == null) {
1050+
return "null";
1051+
}
1052+
final Operator owner = slot.getOwner();
1053+
return describeOperator(owner) + ".out[" + slot.getIndex() + "]:" + slot.getName();
1054+
}
1055+
9791056
@Override
9801057
public String toString() {
9811058
return String.format("PlanImplementation[%s, %s, costs=%s]",

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/enumeration/TopKPruningStrategy.java

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -40,16 +40,8 @@ public void prune(PlanEnumeration planEnumeration) {
4040
if (planEnumeration.getPlanImplementations().size() <= this.k) return;
4141

4242
ArrayList<PlanImplementation> planImplementations = new ArrayList<>(planEnumeration.getPlanImplementations());
43-
planImplementations.sort(this::comparePlanImplementations);
43+
planImplementations.sort(PlanImplementation.costComparator());
4444
planEnumeration.getPlanImplementations().retainAll(planImplementations.subList(0, this.k));
4545
}
4646

47-
48-
private int comparePlanImplementations(PlanImplementation p1,
49-
PlanImplementation p2) {
50-
final double t1 = p1.getSquashedCostEstimate(true);
51-
final double t2 = p2.getSquashedCostEstimate(true);
52-
return Double.compare(t1, t2);
53-
}
54-
5547
}

wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/util/MultiMap.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,14 @@
1818

1919
package org.apache.wayang.core.util;
2020

21-
import java.util.HashMap;
22-
import java.util.HashSet;
21+
import java.util.LinkedHashMap;
22+
import java.util.LinkedHashSet;
2323
import java.util.Set;
2424

2525
/**
2626
* Maps keys to multiple values. Each key value pair is unique.
2727
*/
28-
public class MultiMap<K, V> extends HashMap<K, Set<V>> {
28+
public class MultiMap<K, V> extends LinkedHashMap<K, Set<V>> {
2929

3030
/**
3131
* Associate a key with a new value.
@@ -35,7 +35,7 @@ public class MultiMap<K, V> extends HashMap<K, Set<V>> {
3535
* @return whether the value was not yet associated with the key
3636
*/
3737
public boolean putSingle(K key, V value) {
38-
final Set<V> values = this.computeIfAbsent(key, k -> new HashSet<>());
38+
final Set<V> values = this.computeIfAbsent(key, k -> new LinkedHashSet<>());
3939
return values.add(value);
4040
}
4141

0 commit comments

Comments
 (0)