diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java index b88e1459b2e..57611ffee66 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java @@ -16,6 +16,7 @@ */ package org.apache.calcite.rel.rules; +import org.apache.calcite.linq4j.function.Experimental; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.Calc; @@ -816,4 +817,17 @@ private CoreRules() {} WINDOW_REDUCE_EXPRESSIONS = ReduceExpressionsRule.WindowReduceExpressionsRule.WindowReduceExpressionsRuleConfig .DEFAULT.toRule(); + + /** Rule that flattens a tree of {@link LogicalJoin}s + * into a single {@link HyperGraph} with N inputs. */ + @Experimental + public static final JoinToHyperGraphRule JOIN_TO_HYPER_GRAPH = + JoinToHyperGraphRule.Config.DEFAULT.toRule(); + + /** Rule that re-orders a {@link Join} tree using dphyp algorithm. + * + * @see #JOIN_TO_HYPER_GRAPH */ + @Experimental + public static final DphypJoinReorderRule HYPER_GRAPH_OPTIMIZE = + DphypJoinReorderRule.Config.DEFAULT.toRule(); } diff --git a/core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java b/core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java new file mode 100644 index 00000000000..63c82b9efc3 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.linq4j.function.Experimental; +import org.apache.calcite.plan.RelOptCost; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.HashMap; +import java.util.List; + +/** + * The core process of dphyp enumeration algorithm. + */ +@Experimental +public class DpHyp { + + private final HyperGraph hyperGraph; + + private final HashMap dpTable; + + private final RelBuilder builder; + + private final RelMetadataQuery mq; + + public DpHyp(HyperGraph hyperGraph, RelBuilder builder, RelMetadataQuery relMetadataQuery) { + this.hyperGraph = + hyperGraph.copy( + hyperGraph.getTraitSet(), + hyperGraph.getInputs()); + this.dpTable = new HashMap<>(); + this.builder = builder; + this.mq = relMetadataQuery; + // make all field name unique and convert the + // HyperEdge condition from RexInputRef to RexInputFieldName + this.hyperGraph.convertHyperEdgeCond(builder); + } + + /** + * The entry function of the algorithm. We use a bitmap to represent a leaf node, + * which indicates the position of the corresponding leaf node in {@link HyperGraph}. + * + *

After the enumeration is completed, the best join order will be stored + * in the {@link DpHyp#dpTable}. + */ + public void startEnumerateJoin() { + int size = hyperGraph.getInputs().size(); + for (int i = 0; i < size; i++) { + long singleNode = LongBitmap.newBitmap(i); + dpTable.put(singleNode, hyperGraph.getInput(i)); + hyperGraph.initEdgeBitMap(singleNode); + } + + // start enumerating from the second to last + for (int i = size - 2; i >= 0; i--) { + long csg = LongBitmap.newBitmap(i); + long forbidden = csg - 1; + emitCsg(csg); + enumerateCsgRec(csg, forbidden); + } + } + + /** + * Given a connected subgraph (csg), enumerate all possible complements subgraph (cmp) + * that do not include anything from the exclusion subset. + * + *

Corresponding to EmitCsg in origin paper. + */ + private void emitCsg(long csg) { + long forbidden = csg | LongBitmap.getBvBitmap(csg); + long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden); + + LongBitmap.ReverseIterator reverseIterator = new LongBitmap.ReverseIterator(neighbors); + for (long cmp : reverseIterator) { + List edges = hyperGraph.connectCsgCmp(csg, cmp); + if (!edges.isEmpty()) { + emitCsgCmp(csg, cmp, edges); + } + // forbidden the nodes that smaller than current cmp when extend cmp, e.g. + // neighbors = {t1, t2}, t1 and t2 are connected. + // when extented t2, we will get (t1, t2) + // when extented t1, we will get (t1, t2) repeated + long newForbidden = + (cmp | LongBitmap.getBvBitmap(cmp)) & neighbors; + newForbidden = newForbidden | forbidden; + enumerateCmpRec(csg, cmp, newForbidden); + } + } + + /** + * Given a connected subgraph (csg), expands it recursively by its neighbors. + * If the expanded csg is connected, try to enumerate its cmp (note that for complex hyperedge, + * we only select a single representative node to add to the neighbors, so csg and subNeighbor + * are not necessarily connected. However, it still needs to be expanded to prevent missing + * complex hyperedge). This method is called after the enumeration of csg is completed, + * that is, after {@link DpHyp#emitCsg(long csg)}. + * + *

Corresponding to EnumerateCsgRec in origin paper. + */ + private void enumerateCsgRec(long csg, long forbidden) { + long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden); + LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors); + for (long subNeighbor : subsetIterator) { + hyperGraph.updateEdgesForUnion(csg, subNeighbor); + long newCsg = csg | subNeighbor; + if (dpTable.containsKey(newCsg)) { + emitCsg(newCsg); + } + } + long newForbidden = forbidden | neighbors; + subsetIterator.reset(); + for (long subNeighbor : subsetIterator) { + long newCsg = csg | subNeighbor; + enumerateCsgRec(newCsg, newForbidden); + } + } + + /** + * Given a connected subgraph (csg) and its complement subgraph (cmp), expands the cmp + * recursively by neighbors of cmp (cmp and subNeighbor are not necessarily connected, + * which is the same logic as in {@link DpHyp#enumerateCsgRec}). + * + *

Corresponding to EnumerateCmpRec in origin paper. + */ + private void enumerateCmpRec(long csg, long cmp, long forbidden) { + long neighbors = hyperGraph.getNeighborBitmap(cmp, forbidden); + LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors); + for (long subNeighbor : subsetIterator) { + long newCmp = cmp | subNeighbor; + hyperGraph.updateEdgesForUnion(cmp, subNeighbor); + if (dpTable.containsKey(newCmp)) { + List edges = hyperGraph.connectCsgCmp(csg, newCmp); + if (!edges.isEmpty()) { + emitCsgCmp(csg, newCmp, edges); + } + } + } + long newForbidden = forbidden | neighbors; + subsetIterator.reset(); + for (long subNeighbor : subsetIterator) { + long newCmp = cmp | subNeighbor; + enumerateCmpRec(csg, newCmp, newForbidden); + } + } + + /** + * Given a connected csg-cmp pair and the hyperedges that connect them, build the + * corresponding Join plan. If the new Join plan is better than the existing plan, + * update the {@link DpHyp#dpTable}. + * + *

Corresponding to EmitCsgCmp in origin paper. + */ + private void emitCsgCmp(long csg, long cmp, List edges) { + RelNode child1 = dpTable.get(csg); + RelNode child2 = dpTable.get(cmp); + if (child1 == null || child2 == null) { + throw new IllegalArgumentException( + "csg and cmp were not enumerated in the previous dp process"); + } + + JoinRelType joinType = hyperGraph.extractJoinType(edges); + if (joinType == null) { + return; + } + RexNode joinCond1 = hyperGraph.extractJoinCond(child1, child2, edges); + RelNode newPlan1 = builder + .push(child1) + .push(child2) + .join(joinType, joinCond1) + .build(); + + // swap left and right + RexNode joinCond2 = hyperGraph.extractJoinCond(child2, child1, edges); + RelNode newPlan2 = builder + .push(child2) + .push(child1) + .join(joinType, joinCond2) + .build(); + RelNode winPlan = chooseBetterPlan(newPlan1, newPlan2); + + RelNode oriPlan = dpTable.get(csg | cmp); + if (oriPlan != null) { + winPlan = chooseBetterPlan(winPlan, oriPlan); + } + dpTable.put(csg | cmp, winPlan); + } + + public @Nullable RelNode getBestPlan() { + int size = hyperGraph.getInputs().size(); + long wholeGraph = LongBitmap.newBitmapBetween(0, size); + return dpTable.get(wholeGraph); + } + + private RelNode chooseBetterPlan(RelNode plan1, RelNode plan2) { + RelOptCost cost1 = mq.getCumulativeCost(plan1); + RelOptCost cost2 = mq.getCumulativeCost(plan2); + if (cost1 != null && cost2 != null) { + return cost1.isLt(cost2) ? plan1 : plan2; + } else if (cost1 != null) { + return plan1; + } else { + return plan2; + } + } + +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/DphypJoinReorderRule.java b/core/src/main/java/org/apache/calcite/rel/rules/DphypJoinReorderRule.java new file mode 100644 index 00000000000..8b4a5177224 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/DphypJoinReorderRule.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.linq4j.function.Experimental; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.List; + +/** Rule that re-orders a {@link Join} tree using dphyp algorithm. 
+ * + * @see CoreRules#HYPER_GRAPH_OPTIMIZE */ +@Value.Enclosing +@Experimental +public class DphypJoinReorderRule + extends RelRule + implements TransformationRule { + + protected DphypJoinReorderRule(Config config) { + super(config); + } + + @Override public void onMatch(RelOptRuleCall call) { + HyperGraph hyperGraph = call.rel(0); + RelBuilder relBuilder = call.builder(); + + // enumerate by Dphyp + DpHyp dpHyp = new DpHyp(hyperGraph, relBuilder, call.getMetadataQuery()); + dpHyp.startEnumerateJoin(); + RelNode orderedJoin = dpHyp.getBestPlan(); + if (orderedJoin == null) { + return; + } + + // permute field to origin order + List oriNames = hyperGraph.getRowType().getFieldNames(); + List newNames = orderedJoin.getRowType().getFieldNames(); + List projects = new ArrayList<>(); + RexBuilder rexBuilder = hyperGraph.getCluster().getRexBuilder(); + for (String oriName : oriNames) { + projects.add(rexBuilder.makeInputRef(orderedJoin, newNames.indexOf(oriName))); + } + + RelNode result = call.builder() + .push(orderedJoin) + .project(projects) + .build(); + call.transformTo(result); + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableDphypJoinReorderRule.Config.of() + .withOperandSupplier(b1 -> + b1.operand(HyperGraph.class).anyInputs()); + + @Override default DphypJoinReorderRule toRule() { + return new DphypJoinReorderRule(this); + } + } +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java b/core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java new file mode 100644 index 00000000000..f9f85a43e65 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.linq4j.function.Experimental; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; + +/** + * Edge in HyperGraph, that represents a join predicate. + */ +@Experimental +public class HyperEdge { + + private final long leftNodeBits; + + private final long rightNodeBits; + + private final JoinRelType joinType; + + private final boolean isSimple; + + private final RexNode condition; + + public HyperEdge(long leftNodeBits, long rightNodeBits, JoinRelType joinType, RexNode condition) { + this.leftNodeBits = leftNodeBits; + this.rightNodeBits = rightNodeBits; + this.joinType = joinType; + this.condition = condition; + boolean leftSimple = (leftNodeBits & (leftNodeBits - 1)) == 0; + boolean rightSimple = (rightNodeBits & (rightNodeBits - 1)) == 0; + this.isSimple = leftSimple && rightSimple; + } + + public long getNodeBitmap() { + return leftNodeBits | rightNodeBits; + } + + public long getLeftNodeBitmap() { + return leftNodeBits; + } + + public long getRightNodeBitmap() { + return rightNodeBits; + } + + // hyperedge (u, v) is simple if |u| = |v| = 1 + public boolean isSimple() { + return isSimple; + } + + public JoinRelType getJoinType() { + return joinType; + } + + public RexNode getCondition() { + return condition; + } + + @Override public String toString() { + StringBuilder sb = new 
StringBuilder(); + sb.append(LongBitmap.printBitmap(leftNodeBits)) + .append("——[").append(joinType).append(", ").append(condition).append("]——") + .append(LongBitmap.printBitmap(rightNodeBits)); + return sb.toString(); + } + +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java b/core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java new file mode 100644 index 00000000000..d8fe0fd2ff9 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java @@ -0,0 +1,492 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.linq4j.function.Experimental; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBiVisitor; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVariable; +import org.apache.calcite.rex.RexVisitor; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * HyperGraph represents a join graph. + */ +@Experimental +public class HyperGraph extends AbstractRelNode { + + private final List inputs; + + @SuppressWarnings("HidingField") + private final RelDataType rowType; + + private final List edges; + + // record the indices of complex hyper edges in the 'edges' + private final ImmutableBitSet complexEdgesBitmap; + + /** + * For the HashMap fields, key is the bitmap for inputs, + * value is the hyper edge bitmap in edges. 
+ */ + // record which hyper edges have been used by the enumerated csg-cmp pairs + private final HashMap ccpUsedEdgesMap; + + private final HashMap simpleEdgesMap; + + private final HashMap complexEdgesMap; + + // node bitmap overlaps edge's leftNodeBits or rightNodeBits, but does not completely cover + private final HashMap overlapEdgesMap; + + protected HyperGraph(RelOptCluster cluster, + RelTraitSet traitSet, + List inputs, + List edges, + RelDataType rowType) { + super(cluster, traitSet); + this.inputs = Lists.newArrayList(inputs); + this.edges = Lists.newArrayList(edges); + this.rowType = rowType; + ImmutableBitSet.Builder bitSetBuilder = ImmutableBitSet.builder(); + for (int i = 0; i < edges.size(); i++) { + if (!edges.get(i).isSimple()) { + bitSetBuilder.set(i); + } + } + this.complexEdgesBitmap = bitSetBuilder.build(); + this.ccpUsedEdgesMap = new HashMap<>(); + this.simpleEdgesMap = new HashMap<>(); + this.complexEdgesMap = new HashMap<>(); + this.overlapEdgesMap = new HashMap<>(); + } + + protected HyperGraph(RelOptCluster cluster, + RelTraitSet traitSet, + List inputs, + List edges, + RelDataType rowType, + ImmutableBitSet complexEdgesBitmap, + HashMap ccpUsedEdgesMap, + HashMap simpleEdgesMap, + HashMap complexEdgesMap, + HashMap overlapEdgesMap) { + super(cluster, traitSet); + this.inputs = Lists.newArrayList(inputs); + this.edges = Lists.newArrayList(edges); + this.rowType = rowType; + this.complexEdgesBitmap = complexEdgesBitmap; + this.ccpUsedEdgesMap = new HashMap<>(ccpUsedEdgesMap); + this.simpleEdgesMap = new HashMap<>(simpleEdgesMap); + this.complexEdgesMap = new HashMap<>(complexEdgesMap); + this.overlapEdgesMap = new HashMap<>(overlapEdgesMap); + } + + @Override public HyperGraph copy(RelTraitSet traitSet, List inputs) { + return new HyperGraph( + getCluster(), + traitSet, + inputs, + edges, + rowType, + complexEdgesBitmap, + ccpUsedEdgesMap, + simpleEdgesMap, + complexEdgesMap, + overlapEdgesMap); + } + + @Override public RelWriter 
explainTerms(RelWriter pw) { + super.explainTerms(pw); + for (Ord ord : Ord.zip(inputs)) { + pw.input("input#" + ord.i, ord.e); + } + List hyperEdges = edges.stream() + .map(hyperEdge -> hyperEdge.toString()) + .collect(Collectors.toList()); + pw.item("edges", String.join(",", hyperEdges)); + return pw; + } + + @Override public List getInputs() { + return inputs; + } + + @Override public void replaceInput(int ordinalInParent, RelNode p) { + inputs.set(ordinalInParent, p); + recomputeDigest(); + } + + @Override public RelDataType deriveRowType() { + return rowType; + } + + @Override public RelNode accept(RexShuttle shuttle) { + List shuttleEdges = new ArrayList<>(); + for (HyperEdge edge : edges) { + HyperEdge shuttleEdge = + new HyperEdge( + edge.getLeftNodeBitmap(), + edge.getRightNodeBitmap(), + edge.getJoinType(), + shuttle.apply(edge.getCondition())); + shuttleEdges.add(shuttleEdge); + } + + return new HyperGraph( + getCluster(), + traitSet, + inputs, + shuttleEdges, + rowType, + complexEdgesBitmap, + ccpUsedEdgesMap, + simpleEdgesMap, + complexEdgesMap, + overlapEdgesMap); + } + + //~ hyper graph method ---------------------------------------------------------- + + public List getEdges() { + return edges; + } + + public long getNeighborBitmap(long csg, long forbidden) { + long neighbors = 0L; + List simpleEdges = simpleEdgesMap.getOrDefault(csg, new BitSet()).stream() + .mapToObj(edges::get) + .collect(Collectors.toList()); + for (HyperEdge edge : simpleEdges) { + neighbors |= edge.getNodeBitmap(); + } + + forbidden = forbidden | csg; + neighbors = neighbors & ~forbidden; + forbidden = forbidden | neighbors; + + List complexEdges = complexEdgesMap.getOrDefault(csg, new BitSet()).stream() + .mapToObj(edges::get) + .collect(Collectors.toList()); + for (HyperEdge edge : complexEdges) { + long leftBitmap = edge.getLeftNodeBitmap(); + long rightBitmap = edge.getRightNodeBitmap(); + if (LongBitmap.isSubSet(leftBitmap, csg) && !LongBitmap.isOverlap(rightBitmap, 
forbidden)) { + neighbors |= Long.lowestOneBit(rightBitmap); + } else if (LongBitmap.isSubSet(rightBitmap, csg) + && !LongBitmap.isOverlap(leftBitmap, forbidden)) { + neighbors |= Long.lowestOneBit(leftBitmap); + } + } + return neighbors; + } + + /** + * If csg and cmp are connected, return the edges that connect them. + */ + public List connectCsgCmp(long csg, long cmp) { + checkArgument(simpleEdgesMap.containsKey(csg)); + checkArgument(simpleEdgesMap.containsKey(cmp)); + List connectedEdges = new ArrayList<>(); + BitSet connectedEdgesBitmap = new BitSet(); + connectedEdgesBitmap.or(simpleEdgesMap.getOrDefault(csg, new BitSet())); + connectedEdgesBitmap.or(complexEdgesMap.getOrDefault(csg, new BitSet())); + + BitSet cmpEdgesBitmap = new BitSet(); + cmpEdgesBitmap.or(simpleEdgesMap.getOrDefault(cmp, new BitSet())); + cmpEdgesBitmap.or(complexEdgesMap.getOrDefault(cmp, new BitSet())); + connectedEdgesBitmap.and(cmpEdgesBitmap); + + // only consider the records related to csg and cmp in the simpleEdgesMap/complexEdgesMap, + // may omit some complex hyper edges. e.g. 
+ // csg = {t1, t3}, cmp = {t2}, will omit the edge (t1, t2)——(t3) + BitSet mayMissedEdges = new BitSet(); + mayMissedEdges.or(complexEdgesBitmap.toBitSet()); + mayMissedEdges.andNot(ccpUsedEdgesMap.getOrDefault(csg, new BitSet())); + mayMissedEdges.andNot(ccpUsedEdgesMap.getOrDefault(cmp, new BitSet())); + mayMissedEdges.andNot(connectedEdgesBitmap); + mayMissedEdges.stream() + .forEach(index -> { + HyperEdge edge = edges.get(index); + if (LongBitmap.isSubSet(edge.getNodeBitmap(), csg | cmp)) { + connectedEdgesBitmap.set(index); + } + }); + + // record hyper edges are used by current csg ∪ cmp + BitSet curUsedEdges = new BitSet(); + curUsedEdges.or(connectedEdgesBitmap); + curUsedEdges.or(ccpUsedEdgesMap.getOrDefault(csg, new BitSet())); + curUsedEdges.or(ccpUsedEdgesMap.getOrDefault(cmp, new BitSet())); + if (ccpUsedEdgesMap.containsKey(csg | cmp)) { + checkArgument( + curUsedEdges.equals(ccpUsedEdgesMap.get(csg | cmp))); + } + ccpUsedEdgesMap.put(csg | cmp, curUsedEdges); + + connectedEdgesBitmap.stream() + .forEach(index -> connectedEdges.add(edges.get(index))); + return connectedEdges; + } + + public void initEdgeBitMap(long subset) { + BitSet simpleBitSet = new BitSet(); + BitSet complexBitSet = new BitSet(); + BitSet overlapBitSet = new BitSet(); + for (int i = 0; i < edges.size(); i++) { + HyperEdge edge = edges.get(i); + if (isAccurateEdge(edge, subset)) { + if (edge.isSimple()) { + simpleBitSet.set(i); + } else { + complexBitSet.set(i); + } + } else if (isOverlapEdge(edge, subset)) { + overlapBitSet.set(i); + } + } + simpleEdgesMap.put(subset, simpleBitSet); + complexEdgesMap.put(subset, complexBitSet); + overlapEdgesMap.put(subset, overlapBitSet); + } + + public void updateEdgesForUnion(long subset1, long subset2) { + if (!simpleEdgesMap.containsKey(subset1)) { + initEdgeBitMap(subset1); + } + if (!simpleEdgesMap.containsKey(subset2)) { + initEdgeBitMap(subset2); + } + long unionSet = subset1 | subset2; + if (simpleEdgesMap.containsKey(unionSet)) { + 
return; + } + + BitSet unionSimpleBitSet = new BitSet(); + unionSimpleBitSet.or(simpleEdgesMap.getOrDefault(subset1, new BitSet())); + unionSimpleBitSet.or(simpleEdgesMap.getOrDefault(subset2, new BitSet())); + + BitSet unionComplexBitSet = new BitSet(); + unionComplexBitSet.or(complexEdgesMap.getOrDefault(subset1, new BitSet())); + unionComplexBitSet.or(complexEdgesMap.getOrDefault(subset2, new BitSet())); + + BitSet unionOverlapBitSet = new BitSet(); + unionOverlapBitSet.or(overlapEdgesMap.getOrDefault(subset1, new BitSet())); + unionOverlapBitSet.or(overlapEdgesMap.getOrDefault(subset2, new BitSet())); + + // the overlaps edge that belongs to subset1/subset2 + // may be complex edge for subset1 union subset2 + for (int index : unionOverlapBitSet.stream().toArray()) { + HyperEdge edge = edges.get(index); + if (isAccurateEdge(edge, unionSet)) { + unionComplexBitSet.set(index); + unionOverlapBitSet.set(index, false); + } + } + + // remove cycle in subset1 union subset2 + for (int index : unionSimpleBitSet.stream().toArray()) { + HyperEdge edge = edges.get(index); + if (!isAccurateEdge(edge, unionSet)) { + unionSimpleBitSet.set(index, false); + } + } + for (int index : unionComplexBitSet.stream().toArray()) { + HyperEdge edge = edges.get(index); + if (!isAccurateEdge(edge, unionSet)) { + unionComplexBitSet.set(index, false); + } + } + + simpleEdgesMap.put(unionSet, unionSimpleBitSet); + complexEdgesMap.put(unionSet, unionComplexBitSet); + overlapEdgesMap.put(unionSet, unionOverlapBitSet); + } + + private static boolean isAccurateEdge(HyperEdge edge, long subset) { + boolean isLeftEnd = LongBitmap.isSubSet(edge.getLeftNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getRightNodeBitmap(), subset); + boolean isRightEnd = LongBitmap.isSubSet(edge.getRightNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getLeftNodeBitmap(), subset); + return isLeftEnd || isRightEnd; + } + + private static boolean isOverlapEdge(HyperEdge edge, long subset) { + boolean isLeftEnd 
= LongBitmap.isOverlap(edge.getLeftNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getRightNodeBitmap(), subset); + boolean isRightEnd = LongBitmap.isOverlap(edge.getRightNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getLeftNodeBitmap(), subset); + return isLeftEnd || isRightEnd; + } + + public @Nullable JoinRelType extractJoinType(List edges) { + JoinRelType joinType = edges.get(0).getJoinType(); + for (int i = 1; i < edges.size(); i++) { + if (edges.get(i).getJoinType() != joinType) { + return null; + } + } + return joinType; + } + + public RexNode extractJoinCond(RelNode left, RelNode right, List edges) { + List joinConds = new ArrayList<>(); + List fieldList = new ArrayList<>(left.getRowType().getFieldList()); + fieldList.addAll(right.getRowType().getFieldList()); + + List names = new ArrayList<>(left.getRowType().getFieldNames()); + names.addAll(right.getRowType().getFieldNames()); + + // convert the HyperEdge's condition from RexInputFieldName to RexInputRef + RexShuttle inputName2InputRefShuttle = new RexShuttle() { + @Override protected List visitList( + List exprs, + boolean @Nullable [] update) { + ImmutableList.Builder clonedOperands = ImmutableList.builder(); + for (RexNode operand : exprs) { + RexNode clonedOperand; + if (operand instanceof RexInputFieldName) { + int index = names.indexOf(((RexInputFieldName) operand).getName()); + clonedOperand = new RexInputRef(index, fieldList.get(index).getType()); + } else { + clonedOperand = operand.accept(this); + } + if ((clonedOperand != operand) && (update != null)) { + update[0] = true; + } + clonedOperands.add(clonedOperand); + } + return clonedOperands.build(); + } + }; + + for (HyperEdge edge : edges) { + RexNode inputRefCond = edge.getCondition().accept(inputName2InputRefShuttle); + joinConds.add(inputRefCond); + } + return RexUtil.composeConjunction(left.getCluster().getRexBuilder(), joinConds); + } + + /** + * Before starting enumeration, add Project on every input, make all field name 
unique. + * Convert the HyperEdge condition from RexInputRef to RexInputFieldName + */ + public void convertHyperEdgeCond(RelBuilder builder) { + int fieldIndex = 0; + List fieldList = rowType.getFieldList(); + for (int nodeIndex = 0; nodeIndex < inputs.size(); nodeIndex++) { + RelNode input = inputs.get(nodeIndex); + List projects = new ArrayList<>(); + List names = new ArrayList<>(); + for (int i = 0; i < input.getRowType().getFieldCount(); i++) { + projects.add( + new RexInputRef( + i, + fieldList.get(fieldIndex).getType())); + names.add(fieldList.get(fieldIndex).getName()); + fieldIndex++; + } + + builder.push(input) + .project(projects, names, true); + replaceInput(nodeIndex, builder.build()); + } + + RexShuttle inputRef2inputNameShuttle = new RexShuttle() { + @Override public RexNode visitInputRef(RexInputRef inputRef) { + int index = inputRef.getIndex(); + return new RexInputFieldName( + fieldList.get(index).getName(), + fieldList.get(index).getType()); + } + }; + + for (int i = 0; i < edges.size(); i++) { + HyperEdge edge = edges.get(i); + RexNode convertCond = edge.getCondition().accept(inputRef2inputNameShuttle); + HyperEdge convertEdge = + new HyperEdge( + edge.getLeftNodeBitmap(), + edge.getRightNodeBitmap(), + edge.getJoinType(), + convertCond); + edges.set(i, convertEdge); + } + } + + /** + * Adjusting RexInputRef in enumeration process is too complicated, + * so use unique name replace input ref. + * Before starting enumeration, convert RexInputRef to RexInputFieldName. + * When connect csgcmp to Join, convert RexInputFieldName to RexInputRef. 
+ */ + private static class RexInputFieldName extends RexVariable { + + RexInputFieldName(final String fieldName, final RelDataType type) { + super(fieldName, type); + } + + @Override public R accept(RexVisitor visitor) { + throw new UnsupportedOperationException(); + } + + @Override public R accept(RexBiVisitor visitor, P arg) { + throw new UnsupportedOperationException(); + } + + @Override public boolean equals(@Nullable Object obj) { + return this == obj + || obj instanceof RexInputFieldName + && name == ((RexInputFieldName) obj).name + && type.equals(((RexInputFieldName) obj).type); + } + + @Override public int hashCode() { + return name.hashCode(); + } + + @Override public String toString() { + return name; + } + } +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/JoinToHyperGraphRule.java b/core/src/main/java/org/apache/calcite/rel/rules/JoinToHyperGraphRule.java new file mode 100644 index 00000000000..3ed41e4d6e7 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/JoinToHyperGraphRule.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.linq4j.function.Experimental; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; + +/** Rule that flattens a tree of {@link LogicalJoin}s + * into a single {@link HyperGraph} with N inputs. + * + * @see CoreRules#JOIN_TO_HYPER_GRAPH + */ +@Value.Enclosing +@Experimental +public class JoinToHyperGraphRule + extends RelRule + implements TransformationRule { + + protected JoinToHyperGraphRule(Config config) { + super(config); + } + + @Override public void onMatch(RelOptRuleCall call) { + final Join origJoin = call.rel(0); + final RelNode left = call.rel(1); + final RelNode right = call.rel(2); + if (origJoin.getJoinType() != JoinRelType.INNER) { + return; + } + + HyperGraph result; + List inputs = new ArrayList<>(); + List edges = new ArrayList<>(); + List joinConds = new ArrayList<>(); + + if (origJoin.getCondition().isAlwaysTrue()) { + joinConds.add(origJoin.getCondition()); + } else { + RelOptUtil.decomposeConjunction(origJoin.getCondition(), joinConds); + } + + // when right is HyperGraph, need shift the leftNodeBit, rightNodeBit, condition of HyperEdge + int leftNodeCount; + int leftFieldCount = left.getRowType().getFieldCount(); + if (left instanceof HyperGraph && right instanceof HyperGraph) { + leftNodeCount = left.getInputs().size(); + inputs.addAll(left.getInputs()); + inputs.addAll(right.getInputs()); + 
+ edges.addAll(((HyperGraph) left).getEdges()); + edges.addAll( + ((HyperGraph) right).getEdges().stream() + .map(hyperEdge -> adjustNodeBit(hyperEdge, leftNodeCount, leftFieldCount)) + .collect(Collectors.toList())); + } else if (left instanceof HyperGraph) { + leftNodeCount = left.getInputs().size(); + inputs.addAll(left.getInputs()); + inputs.add(right); + + edges.addAll(((HyperGraph) left).getEdges()); + } else if (right instanceof HyperGraph) { + leftNodeCount = 1; + inputs.add(left); + inputs.addAll(right.getInputs()); + + edges.addAll( + ((HyperGraph) right).getEdges().stream() + .map(hyperEdge -> adjustNodeBit(hyperEdge, leftNodeCount, leftFieldCount)) + .collect(Collectors.toList())); + } else { + leftNodeCount = 1; + inputs.add(left); + inputs.add(right); + } + + HashMap fieldIndexToNodeIndexMap = new HashMap<>(); + int fieldCount = 0; + for (int i = 0; i < inputs.size(); i++) { + for (int j = 0; j < inputs.get(i).getRowType().getFieldCount(); j++) { + fieldIndexToNodeIndexMap.put(fieldCount++, i); + } + } + // convert current join condition to hyper edge condition + for (RexNode joinCond : joinConds) { + long leftNodeBits; + long rightNodeBits; + List leftRefs = new ArrayList<>(); + List rightRefs = new ArrayList<>(); + + RexVisitorImpl visitor = new RexVisitorImpl(true) { + @Override public Void visitInputRef(RexInputRef inputRef) { + Integer nodeIndex = fieldIndexToNodeIndexMap.get(inputRef.getIndex()); + if (nodeIndex == null) { + throw new IllegalArgumentException("RexInputRef refers a dummy field: " + + inputRef + ", rowType is: " + origJoin.getRowType()); + } + if (nodeIndex < leftNodeCount) { + leftRefs.add(nodeIndex); + } else { + rightRefs.add(nodeIndex); + } + return null; + } + }; + joinCond.accept(visitor); + + // when cartesian product, make it to complex hyper edge + if (leftRefs.isEmpty() || rightRefs.isEmpty()) { + leftNodeBits = LongBitmap.newBitmapBetween(0, leftNodeCount); + rightNodeBits = LongBitmap.newBitmapBetween(leftNodeCount, 
inputs.size()); + } else { + leftNodeBits = LongBitmap.newBitmapFromList(leftRefs); + rightNodeBits = LongBitmap.newBitmapFromList(rightRefs); + } + edges.add( + new HyperEdge( + leftNodeBits, + rightNodeBits, + origJoin.getJoinType(), + joinCond)); + } + result = + new HyperGraph( + origJoin.getCluster(), + origJoin.getTraitSet(), + inputs, + edges, + origJoin.getRowType()); + + call.transformTo(result); + } + + private static HyperEdge adjustNodeBit(HyperEdge hyperEdge, int nodeOffset, int fieldOffset) { + RexNode newCondition = RexUtil.shift(hyperEdge.getCondition(), fieldOffset); + return new HyperEdge( + hyperEdge.getLeftNodeBitmap() << nodeOffset, + hyperEdge.getRightNodeBitmap() << nodeOffset, + hyperEdge.getJoinType(), + newCondition); + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableJoinToHyperGraphRule.Config.of() + .withOperandSupplier(b1 -> + b1.operand(Join.class).inputs( + b2 -> b2.operand(RelNode.class).anyInputs(), + b3 -> b3.operand(RelNode.class).anyInputs())); + + @Override default JoinToHyperGraphRule toRule() { + return new JoinToHyperGraphRule(this); + } + } + +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/LongBitmap.java b/core/src/main/java/org/apache/calcite/rel/rules/LongBitmap.java new file mode 100644 index 00000000000..39ca3ec586c --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/LongBitmap.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Bitmap tool for dphyp.
 *
 * <p>A set of node indexes is encoded in a {@code long}: bit {@code i} is set
 * when index {@code i} belongs to the set. All helpers are static.
 */
public class LongBitmap {

  private LongBitmap() {}

  /**
   * Creates a bitmap with every bit from {@code startInclude} (inclusive) up to
   * {@code endExclude} (exclusive) set. Returns 0 when the range is empty.
   */
  public static long newBitmapBetween(int startInclude, int endExclude) {
    long bitmap = 0;
    for (int i = startInclude; i < endExclude; i++) {
      bitmap |= 1L << i;
    }
    return bitmap;
  }

  /** Creates a bitmap in which only bit {@code value} is set. */
  public static long newBitmap(int value) {
    return 1L << value;
  }

  /**
   * Corresponding to Bv = {node|node ≺ csg} in "Dynamic programming strikes back".
   *
   * <p>Returns a bitmap of all bits below the lowest set bit of {@code csg}.
   */
  public static long getBvBitmap(long csg) {
    // (csg & -csg) isolates the lowest set bit; subtracting 1 sets every bit below it.
    return (csg & -csg) - 1;
  }

  /** Returns whether every bit set in {@code maySub} is also set in {@code bigger}. */
  public static boolean isSubSet(long maySub, long bigger) {
    return (bigger | maySub) == bigger;
  }

  /** Returns whether the two bitmaps have at least one set bit in common. */
  public static boolean isOverlap(long bitmap1, long bitmap2) {
    return (bitmap1 & bitmap2) != 0;
  }

  /** Creates a bitmap with one bit set for each index in {@code values}. */
  public static long newBitmapFromList(List<Integer> values) {
    long bitmap = 0;
    for (int value : values) {
      bitmap |= 1L << value;
    }
    return bitmap;
  }

  /**
   * Formats the indexes of the set bits, lowest first, as {@code {0, 1, 3}}.
   * An empty bitmap is formatted as <code>{}</code>.
   */
  public static String printBitmap(long bitmap) {
    StringBuilder sb = new StringBuilder("{");
    String separator = "";
    // rest & (rest - 1) clears the lowest set bit on each step.
    for (long rest = bitmap; rest != 0; rest &= rest - 1) {
      sb.append(separator).append(Long.numberOfTrailingZeros(rest));
      separator = ", ";
    }
    return sb.append('}').toString();
  }

  /**
   * Traverse the bitmap in reverse order.
   *
   * <p>Each element is a bitmap with a single bit set, starting from the
   * highest set bit. Iteration consumes the state held by this object, so an
   * instance yields its elements only once.
   */
  public static class ReverseIterator implements Iterable<Long> {

    private long bitmap;

    public ReverseIterator(long bitmap) {
      this.bitmap = bitmap;
    }

    @Override public Iterator<Long> iterator() {
      return new Iterator<Long>() {
        @Override public boolean hasNext() {
          return bitmap != 0;
        }

        @Override public Long next() {
          if (bitmap == 0) {
            throw new java.util.NoSuchElementException();
          }
          long res = Long.highestOneBit(bitmap);
          bitmap &= ~res;
          return res;
        }
      };
    }
  }

  /**
   * Enumerate all subsets of a bitmap from small to large.
   *
   * <p>Only non-empty subsets are produced, in ascending numeric order of their
   * bitmap value. The position is kept in this object and shared by every
   * iterator it returns; call {@link #reset()} before iterating again.
   */
  public static class SubsetIterator implements Iterable<Long> {

    private final List<Long> subsetList;

    private int index;

    public SubsetIterator(long bitmap) {
      this.subsetList = new ArrayList<>();

      // (s - 1) & bitmap steps to the next smaller subset of bitmap, so the
      // list is filled from the largest subset down to the smallest.
      long curBiggestSubset = bitmap;
      while (curBiggestSubset != 0) {
        subsetList.add(curBiggestSubset);
        curBiggestSubset = (curBiggestSubset - 1) & bitmap;
      }

      // Start at the end of the list, i.e. at the smallest subset.
      this.index = subsetList.size() - 1;
    }

    @Override public Iterator<Long> iterator() {
      return new Iterator<Long>() {
        @Override public boolean hasNext() {
          return index >= 0;
        }

        @Override public Long next() {
          if (index < 0) {
            throw new java.util.NoSuchElementException();
          }
          return subsetList.get(index--);
        }
      };
    }

    /** Moves the shared position back to the smallest subset. */
    public void reset() {
      index = subsetList.size() - 1;
    }
  }

}
emp_address, dept, dept_nested " + + "where emp.deptno + emp_address.empno = dept.deptno + dept_nested.deptno " + + "and emp.empno = emp_address.empno " + + "and dept.deptno = dept_nested.deptno") + .withPre(program) + .withRule(CoreRules.HYPER_GRAPH_OPTIMIZE, CoreRules.PROJECT_REMOVE, CoreRules.PROJECT_MERGE) + .check(); + } + + @Test void testStarJoinDphypJoinReorder() { + HepProgram program = new HepProgramBuilder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addRuleInstance(CoreRules.JOIN_TO_HYPER_GRAPH) + .build(); + + sql("select emp.empno from " + + "emp, emp_b, emp_address, dept, dept_nested " + + "where emp.empno = emp_b.empno " + + "and emp.empno = emp_address.empno " + + "and emp.deptno = dept.deptno " + + "and emp.deptno = dept_nested.deptno " + + "and emp_b.sal + emp_address.empno = dept.deptno + dept_nested.deptno") + .withPre(program) + .withRule(CoreRules.HYPER_GRAPH_OPTIMIZE, CoreRules.PROJECT_REMOVE, CoreRules.PROJECT_MERGE) + .check(); + } + + @Test void testCycleJoinDphypJoinReorder() { + HepProgram program = new HepProgramBuilder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addRuleInstance(CoreRules.JOIN_TO_HYPER_GRAPH) + .build(); + + sql("select emp.empno from " + + "emp, emp_b, dept, dept_nested " + + "where emp.empno = emp_b.empno " + + "and emp_b.deptno = dept.deptno " + + "and dept.name = dept_nested.name " + + "and dept_nested.deptno = emp.deptno " + + "and emp.sal + emp_b.sal = dept.deptno + dept_nested.deptno") + .withPre(program) + .withRule(CoreRules.HYPER_GRAPH_OPTIMIZE, CoreRules.PROJECT_REMOVE, CoreRules.PROJECT_MERGE) + .check(); + } } diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 153864230c5..12bc4eace8a 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ 
b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -1644,6 +1644,33 @@ case when cast(ename as double) < 5 then 0.0 ($1, 'abc'), $1, null:VARCHAR(20))):DOUBLE, 5.0E0), 0.0E0:DOUBLE, CASE(IS NOT NULL(CAST(CASE(>($1, 'abc'), $1, null:VARCHAR(20))):DOUBLE), CAST(CAST(CASE(>($1, 'abc'), $1, null:VARCHAR(20))):DOUBLE):DOUBLE NOT NULL, 1.0E0:DOUBLE))]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + + + + @@ -2136,6 +2163,33 @@ and e1.deptno < 10 and d1.deptno < 15 and e1.sal > (select avg(sal) from emp e2 where e1.empno = e2.empno)]]> + + + + + + + + + + + + + + + + + + + + + +