Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
* This graph contains a set of {@link ChannelConversion}s.
Expand Down Expand Up @@ -168,7 +169,7 @@ private Tree mergeTrees(Collection<Tree> trees) {
final Tree firstTree = iterator.next();
Bitmask combinationSettledIndices = new Bitmask(firstTree.settledDestinationIndices);
int maxSettledIndices = combinationSettledIndices.cardinality();
final HashSet<ChannelDescriptor> employedChannelDescriptors = new HashSet<>(firstTree.employedChannelDescriptors);
final LinkedHashSet<ChannelDescriptor> employedChannelDescriptors = new LinkedHashSet<>(firstTree.employedChannelDescriptors);
int maxVisitedChannelDescriptors = employedChannelDescriptors.size();
double costs = firstTree.costs;
TreeVertex newRoot = new TreeVertex(firstTree.root.channelDescriptor, firstTree.root.settledIndices);
Expand Down Expand Up @@ -222,7 +223,11 @@ public static class CostbasedTreeSelectionStrategy implements TreeSelectionStrat

/**
 * Chooses the cheaper of two {@link Tree}s. Costs are compared with
 * {@link Double#compare(double, double)}; when they compare as equal, the tie is broken by the
 * lexicographic order of the trees' deterministic signatures. If the signatures are equal as well,
 * {@code t1} is returned.
 *
 * @param t1 the first candidate
 * @param t2 the second candidate
 * @return {@code t1} unless {@code t2} orders strictly before it
 */
@Override
public Tree select(Tree t1, Tree t2) {
    // The superseded "return t1.costs <= t2.costs ? t1 : t2;" must not precede this logic:
    // it would make the tie-break below unreachable.
    int cmp = Double.compare(t1.costs, t2.costs);
    if (cmp == 0) {
        // Equal costs: fall back to the signatures so that the result does not depend on argument order alone.
        cmp = t1.getDeterministicSignature().compareTo(t2.getDeterministicSignature());
    }
    return cmp <= 0 ? t1 : t2;
}

}
Expand Down Expand Up @@ -381,7 +386,7 @@ private ShortestTreeSearcher(OutputSlot<?> sourceOutput,
this.existingDestinationChannelIndices = new Bitmask();

this.collectExistingChannels(sourceChannel);
this.openChannelDescriptors = new HashSet<>(openChannels.size());
this.openChannelDescriptors = new LinkedHashSet<>(openChannels.size());
for (Channel openChannel : openChannels) {
this.openChannelDescriptors.add(openChannel.getDescriptor());
}
Expand Down Expand Up @@ -477,7 +482,9 @@ private Set<ChannelDescriptor> resolveSupportedChannels(final InputSlot<?> input
final List<ChannelDescriptor> supportedInputChannels = owner.getSupportedInputChannels(input.getIndex());
if (input.isLoopInvariant()) {
// Loop input is needed in several iterations and must therefore be reusable.
return supportedInputChannels.stream().filter(ChannelDescriptor::isReusable).collect(Collectors.toSet());
return supportedInputChannels.stream()
.filter(ChannelDescriptor::isReusable)
.collect(Collectors.toCollection(LinkedHashSet::new));
} else {
return WayangCollections.asSet(supportedInputChannels);
}
Expand Down Expand Up @@ -546,7 +553,7 @@ private void kernelizeChannelRequests() {
}
if (channelDescriptors.size() - numReusableChannels == 1) {
iterator.remove();
channelDescriptors = new HashSet<>(channelDescriptors);
channelDescriptors = new LinkedHashSet<>(channelDescriptors);
channelDescriptors.removeIf(channelDescriptor -> !channelDescriptor.isReusable());
kernelDestChannelDescriptorSetsToIndicesUpdates.add(new Tuple<>(channelDescriptors, indices));
}
Expand Down Expand Up @@ -575,7 +582,7 @@ private void kernelizeChannelRequests() {
*/
private Tree searchTree() {
// Prepare the recursive traversal.
final HashSet<ChannelDescriptor> visitedChannelDescriptors = new HashSet<>(16);
final LinkedHashSet<ChannelDescriptor> visitedChannelDescriptors = new LinkedHashSet<>(16);
visitedChannelDescriptors.add(this.sourceChannelDescriptor);

// Perform the traversal.
Expand Down Expand Up @@ -777,7 +784,7 @@ private Set<ChannelDescriptor> getSuccessorChannelDescriptors(ChannelDescriptor
final Channel channel = this.existingChannels.get(descriptor);
if (channel == null || this.openChannelDescriptors.contains(descriptor)) return null;

Set<ChannelDescriptor> result = new HashSet<>();
Set<ChannelDescriptor> result = new LinkedHashSet<>();
for (ExecutionTask consumer : channel.getConsumers()) {
if (!consumer.getOperator().isAuxiliary()) continue;
for (Channel successorChannel : consumer.getOutputChannels()) {
Expand Down Expand Up @@ -988,7 +995,12 @@ private static class Tree {
*
* @see TreeVertex#channelDescriptor
*/
private final Set<ChannelDescriptor> employedChannelDescriptors = new HashSet<>();
private final Set<ChannelDescriptor> employedChannelDescriptors = new LinkedHashSet<>();

/**
* Cached deterministic signature for tie-breaking.
*/
private String deterministicSignature;

/**
* The sum of the costs of all {@link TreeEdge}s of this instance.
Expand All @@ -1010,6 +1022,7 @@ static Tree singleton(ChannelDescriptor channelDescriptor, Bitmask settledIndice
this.root = root;
this.settledDestinationIndices = settledDestinationIndices;
this.employedChannelDescriptors.add(root.channelDescriptor);
this.deterministicSignature = null;
}

/**
Expand All @@ -1033,6 +1046,21 @@ void reroot(ChannelDescriptor newRootChannelDescriptor,
this.employedChannelDescriptors.add(newRootChannelDescriptor);
this.settledDestinationIndices.orInPlace(newRootSettledIndices);
this.costs += edge.costEstimate;
this.deterministicSignature = null;
}

/**
 * Returns the memoized signature of this instance, building it on first use. The signature consists of
 * the string forms of all employed channel descriptors (sorted, separated by {@code |}), a {@code #},
 * and the settled destination indices in iteration order (separated by {@code ,}).
 *
 * @return the signature
 */
private String getDeterministicSignature() {
    // Serve the cached value if it is still present.
    if (this.deterministicSignature != null) {
        return this.deterministicSignature;
    }

    // Sorting the descriptor strings makes this part independent of the set's iteration order.
    final List<String> descriptorNames = this.employedChannelDescriptors.stream()
            .map(descriptor -> descriptor.toString())
            .sorted()
            .collect(Collectors.toList());
    final List<String> settledIndices = StreamSupport.stream(this.settledDestinationIndices.spliterator(), false)
            .map(index -> String.valueOf(index))
            .collect(Collectors.toList());

    this.deterministicSignature = String.join("|", descriptorNames) + "#" + String.join(",", settledIndices);
    return this.deterministicSignature;
}

@Override
Expand Down Expand Up @@ -1090,7 +1118,7 @@ private void copyEdgesFrom(TreeVertex that) {
* @return a {@link Set} of said {@link ChannelConversion}s
*/
private Set<ChannelConversion> getChildChannelConversions() {
Set<ChannelConversion> channelConversions = new HashSet<>();
Set<ChannelConversion> channelConversions = new LinkedHashSet<>();
for (TreeEdge edge : this.outEdges) {
channelConversions.add(edge.channelConversion);
channelConversions.addAll(edge.destination.getChildChannelConversions());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,7 @@ public class DefaultEstimatableCost implements EstimatableCost {
Set<ExecutionStage> executedStages
) {
final PlanImplementation bestPlanImplementation = executionPlans.stream()
.reduce((p1, p2) -> {
final double t1 = p1.getSquashedCostEstimate();
final double t2 = p2.getSquashedCostEstimate();
return t1 < t2 ? p1 : p2;
})
.min(PlanImplementation.costComparator())
.orElseThrow(() -> new WayangException("Could not find an execution plan."));
return bestPlanImplementation;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private PlanImplementation selectBestPlanBinary(PlanImplementation p1,
PlanImplementation p2) {
final double t1 = p1.getSquashedCostEstimate(true);
final double t2 = p2.getSquashedCostEstimate(true);
final boolean isPickP1 = t1 <= t2;
final boolean isPickP1 = PlanImplementation.costComparator().compare(p1, p2) <= 0;
if (logger.isDebugEnabled()) {
if (isPickP1) {
LogManager.getLogger(LatentOperatorPruningStrategy.class).debug(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -91,7 +91,7 @@ public class PlanEnumeration {
* Creates a new instance.
*/
public PlanEnumeration() {
    // Exactly one explicit constructor invocation is allowed; the stale HashSet-based delegation is dropped
    // in favour of insertion-ordered sets.
    this(new LinkedHashSet<>(), new LinkedHashSet<>(), new LinkedHashSet<>());
}

/**
Expand Down Expand Up @@ -412,16 +412,15 @@ private Collection<PlanImplementation> concatenatePartialPlansBatchwise(
if (junction == null) continue;

// If we found a junction, then we can enumerate all PlanImplementation combinations.
final List<Set<PlanImplementation>> groupPlans = WayangCollections.map(
concatGroupCombo,
concatGroup -> {
Set<PlanImplementation.ConcatenationDescriptor> concatDescriptors = concatGroup2concatDescriptor.get(concatGroup);
Set<PlanImplementation> planImplementations = new HashSet<>(concatDescriptors.size());
for (PlanImplementation.ConcatenationDescriptor concatDescriptor : concatDescriptors) {
planImplementations.add(concatDescriptor.getPlanImplementation());
}
return planImplementations;
});
final List<List<PlanImplementation>> groupPlans = WayangCollections.map(
concatGroupCombo,
concatGroup -> {
Set<PlanImplementation.ConcatenationDescriptor> concatDescriptors = concatGroup2concatDescriptor.get(concatGroup);
return concatDescriptors.stream()
.map(PlanImplementation.ConcatenationDescriptor::getPlanImplementation)
.sorted(PlanImplementation.structuralComparator())
.collect(Collectors.toList());
});

for (List<PlanImplementation> planCombo : WayangCollections.streamedCrossProduct(groupPlans)) {
PlanImplementation basePlan = planCombo.get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,19 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
Expand All @@ -65,6 +68,11 @@
public class PlanImplementation {

private static final Logger logger = LogManager.getLogger(PlanImplementation.class);
// Orders plans by getSquashedCostEstimate(true); plans with equal estimates are ordered by
// their deterministic identifier.
private static final Comparator<PlanImplementation> COST_COMPARATOR =
Comparator.comparingDouble((PlanImplementation plan) -> plan.getSquashedCostEstimate(true))
.thenComparing(PlanImplementation::getDeterministicIdentifier);
// Orders plans by their deterministic identifier only; costs are not consulted.
private static final Comparator<PlanImplementation> STRUCTURAL_COMPARATOR =
Comparator.comparing(PlanImplementation::getDeterministicIdentifier);

/**
* {@link ExecutionOperator}s contained in this instance.
Expand Down Expand Up @@ -180,6 +188,14 @@ private PlanImplementation(PlanEnumeration planEnumeration,
assert this.planEnumeration != null;
}

/**
 * Provides the shared {@code COST_COMPARATOR} constant of this class.
 *
 * @return the {@link Comparator}
 */
public static Comparator<PlanImplementation> costComparator() {
return COST_COMPARATOR;
}

/**
 * Provides the shared {@code STRUCTURAL_COMPARATOR} constant of this class.
 *
 * @return the {@link Comparator}
 */
public static Comparator<PlanImplementation> structuralComparator() {
return STRUCTURAL_COMPARATOR;
}


/**
* @return the {@link PlanEnumeration} this instance belongs to
Expand Down Expand Up @@ -238,7 +254,7 @@ Collection<InputSlot<?>> findExecutionOperatorInputs(final InputSlot<?> someInpu

// Discern LoopHeadOperator InputSlots and loop body InputSlots.
final List<LoopImplementation.IterationImplementation> iterationImpls = loopImplementation.getIterationImplementations();
final Collection<InputSlot<?>> collector = new HashSet<>(innerInputs.size());
final Collection<InputSlot<?>> collector = new LinkedHashSet<>(innerInputs.size());
for (InputSlot<?> innerInput : innerInputs) {
if (innerInput.getOwner() == loopSubplan.getLoopHead()) {
final LoopImplementation.IterationImplementation initialIterationImpl = iterationImpls.get(0);
Expand Down Expand Up @@ -312,7 +328,7 @@ Collection<Tuple<OutputSlot<?>, PlanImplementation>> findExecutionOperatorOutput
// For all the iterations, return the potential OutputSlots.
final List<LoopImplementation.IterationImplementation> iterationImpls =
loopImplementation.getIterationImplementations();
final Set<Tuple<OutputSlot<?>, PlanImplementation>> collector = new HashSet<>(iterationImpls.size());
final Set<Tuple<OutputSlot<?>, PlanImplementation>> collector = new LinkedHashSet<>(iterationImpls.size());
for (LoopImplementation.IterationImplementation iterationImpl : iterationImpls) {
final Collection<Tuple<OutputSlot<?>, PlanImplementation>> outputsWithContext =
iterationImpl.getBodyImplementation().findExecutionOperatorOutputWithContext(innerOutput);
Expand Down Expand Up @@ -678,8 +694,8 @@ public double getSquashedCostEstimate() {

private Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(Operator operator) {

Set<Operator> inputOperators = new HashSet<>();
Set<Junction> inputJunction = new HashSet<>();
Set<Operator> inputOperators = new LinkedHashSet<>();
Set<Junction> inputJunction = new LinkedHashSet<>();

List<ProbabilisticDoubleInterval> probalisticCost = new ArrayList<>();
List<Double> squashedCost = new ArrayList<>();
Expand Down Expand Up @@ -976,6 +992,67 @@ Stream<ExecutionOperator> streamOperators() {
return operatorStream;
}

/**
 * Builds a textual identifier from the current operators, junctions, and loop implementations of this
 * plan. Each of the three sections lists its sorted element descriptions separated by {@code |}; the
 * sections themselves are separated by {@code #}. The identifier is not guaranteed to be unique, but it is
 * meant to be stable across runs for the same logical plan, so that it can serve as a reproducible
 * ordering criterion.
 *
 * @return the deterministic identifier
 */
public String getDeterministicIdentifier() {
    final String operatorSection = joinSorted(
            this.operators.stream().map(PlanImplementation::describeOperator));
    final String junctionSection = joinSorted(
            this.junctions.values().stream().map(PlanImplementation::describeJunction));
    final String loopSection = joinSorted(
            this.loopImplementations.entrySet().stream()
                    .map(entry -> describeLoop(entry.getKey(), entry.getValue())));
    return String.join("#", operatorSection, junctionSection, loopSection);
}

/**
 * Sorts the given descriptions in their natural order and concatenates them with {@code |}.
 */
private static String joinSorted(Stream<String> descriptions) {
    return descriptions.sorted().collect(Collectors.joining("|"));
}

/**
 * Renders an {@link Operator} as {@code <class name>:<name>:<epoch>}; a {@code null} name is rendered as
 * the empty string.
 *
 * @param operator the {@link Operator} to describe
 * @return the description
 */
private static String describeOperator(Operator operator) {
    final String operatorName = operator.getName();
    final StringBuilder description = new StringBuilder(operator.getClass().getName()).append(':');
    if (operatorName != null) {
        description.append(operatorName);
    }
    return description.append(':').append(operator.getEpoch()).toString();
}

/**
 * Renders a {@link Junction} as {@code <source output>-><target inputs>}, where the target input
 * descriptions are sorted and separated by {@code ,}.
 *
 * @param junction the {@link Junction} to describe
 * @return the description
 */
private static String describeJunction(Junction junction) {
    final String sourceText = describeOutputSlot(junction.getSourceOutput());

    // Describe every target input, then sort so that the target order does not matter.
    final int numTargets = junction.getNumTargets();
    final List<String> targetTexts = new ArrayList<>();
    for (int targetIndex = 0; targetIndex < numTargets; targetIndex++) {
        targetTexts.add(describeInputSlot(junction.getTargetInput(targetIndex)));
    }
    Collections.sort(targetTexts);

    return sourceText + "->" + String.join(",", targetTexts);
}

/**
 * Renders a loop as the description of its {@link LoopSubplan}, a {@code :}, and the iteration counts of
 * all its iteration implementations in list order, separated by {@code ,}.
 *
 * @param loop           the {@link LoopSubplan}
 * @param implementation the {@link LoopImplementation} chosen for {@code loop}
 * @return the description
 */
private static String describeLoop(LoopSubplan loop, LoopImplementation implementation) {
    final StringBuilder description = new StringBuilder(describeOperator(loop)).append(':');
    boolean isFirst = true;
    for (LoopImplementation.IterationImplementation iteration : implementation.getIterationImplementations()) {
        if (!isFirst) {
            description.append(',');
        }
        description.append(iteration.getNumIterations());
        isFirst = false;
    }
    return description.toString();
}

private static String describeInputSlot(InputSlot<?> slot) {
if (slot == null) {
return "null";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this just return an actual null then? String is nullable, so this can be a valid value.
Or would this be a problem on comparison later on?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning a "real" null would break the deterministic identifier: the values are .sorted() and then joined, and sorted() with natural ordering throws a NullPointerException as soon as it has to compare a null element. The literal "null" keeps the result stable and safe for comparison/concatenation. We could switch to real nulls only if we add null-safe sorting/joining (e.g. Comparator.nullsFirst).

}
final Operator owner = slot.getOwner();
return describeOperator(owner) + ".in[" + slot.getIndex() + "]:" + slot.getName();
}

/**
 * Renders an {@link OutputSlot} as {@code <owner description>.out[<index>]:<name>}, or as the literal
 * {@code "null"} if there is no slot.
 *
 * @param slot the {@link OutputSlot} to describe, may be {@code null}
 * @return the description
 */
private static String describeOutputSlot(OutputSlot<?> slot) {
    return slot == null
            ? "null"
            : describeOperator(slot.getOwner()) + ".out[" + slot.getIndex() + "]:" + slot.getName();
}

@Override
public String toString() {
return String.format("PlanImplementation[%s, %s, costs=%s]",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,8 @@ public void prune(PlanEnumeration planEnumeration) {
if (planEnumeration.getPlanImplementations().size() <= this.k) return;

ArrayList<PlanImplementation> planImplementations = new ArrayList<>(planEnumeration.getPlanImplementations());
planImplementations.sort(this::comparePlanImplementations);
planImplementations.sort(PlanImplementation.costComparator());
planEnumeration.getPlanImplementations().retainAll(planImplementations.subList(0, this.k));
}


/**
 * Compares two {@link PlanImplementation}s by their squashed cost estimates, each obtained via
 * {@code getSquashedCostEstimate(true)}.
 *
 * @return the result of {@link Double#compare(double, double)} on the two estimates
 */
private int comparePlanImplementations(PlanImplementation p1,
                                       PlanImplementation p2) {
    return Double.compare(p1.getSquashedCostEstimate(true), p2.getSquashedCostEstimate(true));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

package org.apache.wayang.core.util;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;

/**
* Maps keys to multiple values. Each key value pair is unique.
*/
public class MultiMap<K, V> extends HashMap<K, Set<V>> {
public class MultiMap<K, V> extends LinkedHashMap<K, Set<V>> {

/**
* Associate a key with a new value.
Expand All @@ -35,7 +35,7 @@ public class MultiMap<K, V> extends HashMap<K, Set<V>> {
* @return whether the value was not yet associated with the key
*/
public boolean putSingle(K key, V value) {
    // Create the value set on first use of the key. Only one declaration of the local may exist; the
    // insertion-ordered LinkedHashSet variant is kept and the stale HashSet one is dropped.
    final Set<V> values = this.computeIfAbsent(key, k -> new LinkedHashSet<>());
    // Set#add reports whether the value was newly added for this key.
    return values.add(value);
}

Expand Down
Loading
Loading