diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java index b88e1459b2ea..72388a4f619d 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java @@ -816,4 +816,14 @@ private CoreRules() {} WINDOW_REDUCE_EXPRESSIONS = ReduceExpressionsRule.WindowReduceExpressionsRule.WindowReduceExpressionsRuleConfig .DEFAULT.toRule(); + /** Rule that flattens a tree of {@link LogicalJoin}s + * into a single {@link HyperGraph} with N inputs. */ + public static final JoinToHyperGraphRule JOIN_TO_HYPER_GRAPH = + JoinToHyperGraphRule.Config.DEFAULT.toRule(); + + /** Rule that re-orders a {@link Join} tree using dphyp algorithm. + * + * @see #JOIN_TO_HYPER_GRAPH */ + public static final DphypJoinReorderRule HYPER_GRAPH_OPTIMIZE = + DphypJoinReorderRule.Config.DEFAULT.toRule(); } diff --git a/core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java b/core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java new file mode 100644 index 000000000000..3225c7eaf3a0 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/DpHyp.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.HashMap; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * The core process of dphyp enumeration algorithm. + */ +public class DpHyp { + + private HyperGraph hyperGraph; + + private HashMap dpTable; + + private RelBuilder builder; + + private RelMetadataQuery mq; + + public DpHyp(HyperGraph hyperGraph, RelBuilder builder, RelMetadataQuery relMetadataQuery) { + this.hyperGraph = hyperGraph; + this.dpTable = new HashMap<>(); + this.builder = builder; + this.mq = relMetadataQuery; + } + + public void startEnumerateJoin() { + int size = hyperGraph.getInputs().size(); + for (int i = 0; i < size; i++) { + long singleNode = LongBitmap.newBitmap(i); + dpTable.put(singleNode, hyperGraph.getInput(i)); + hyperGraph.initEdgeBitMap(singleNode); + } + + // start enumerating from the second to last + for (int i = size - 2; i >= 0; i--) { + long csg = LongBitmap.newBitmap(i); + long forbidden = csg - 1; + emitCsg(csg); + enumerateCsgRec(csg, forbidden); + } + } + + private void emitCsg(long csg) { + long forbidden = csg | LongBitmap.getBvBitmap(csg); + long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden); + + LongBitmap.ReverseIterator reverseIterator = new LongBitmap.ReverseIterator(neighbors); + for (long cmp : reverseIterator) { + List edges = hyperGraph.connectCsgCmp(csg, cmp); + if (!edges.isEmpty()) { + emitCsgCmp(csg, cmp, edges); + } + // forbidden the nodes that smaller than current cmp when extend cmp, e.g. + // neighbors = {t1, t2}, t1 and t2 are connected. + // when extented t2, we will get (t1, t2) + // when extented t1, we will get (t1, t2) repeated + long newForbidden = + (cmp | LongBitmap.getBvBitmap(cmp)) & neighbors; + newForbidden = newForbidden | forbidden; + enumerateCmpRec(csg, cmp, newForbidden); + } + } + + private void enumerateCsgRec(long csg, long forbidden) { + long neighbors = hyperGraph.getNeighborBitmap(csg, forbidden); + LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors); + for (long subNeighbor : subsetIterator) { + hyperGraph.updateEdgesForUnion(csg, subNeighbor); + long newCsg = csg | subNeighbor; + if (dpTable.containsKey(newCsg)) { + emitCsg(newCsg); + } + } + long newForbidden = forbidden | neighbors; + subsetIterator.reset(); + for (long subNeighbor : subsetIterator) { + long newCsg = csg | subNeighbor; + enumerateCsgRec(newCsg, newForbidden); + } + } + + private void enumerateCmpRec(long csg, long cmp, long forbidden) { + long neighbors = hyperGraph.getNeighborBitmap(cmp, forbidden); + LongBitmap.SubsetIterator subsetIterator = new LongBitmap.SubsetIterator(neighbors); + for (long subNeighbor : subsetIterator) { + long newCmp = cmp | subNeighbor; + hyperGraph.updateEdgesForUnion(cmp, subNeighbor); + if (dpTable.containsKey(newCmp)) { + List edges = hyperGraph.connectCsgCmp(csg, newCmp); + if (!edges.isEmpty()) { + emitCsgCmp(csg, newCmp, edges); + } + } + } + long newForbidden = forbidden | neighbors; + subsetIterator.reset(); + for (long subNeighbor : subsetIterator) { + long newCmp = cmp | subNeighbor; + enumerateCmpRec(csg, newCmp, newForbidden); + } + } + + private void emitCsgCmp(long csg, long cmp, List edges) { + checkArgument(dpTable.containsKey(csg)); + checkArgument(dpTable.containsKey(cmp)); + RelNode child1 = dpTable.get(csg); + RelNode child2 = dpTable.get(cmp); + + JoinRelType joinType = hyperGraph.extractJoinType(edges); + if (joinType == null) { + return; + } + RexNode joinCond1 = hyperGraph.extractJoinCond(child1, child2, edges); + RelNode newPlan1 = builder + .push(child1) + .push(child2) + .join(joinType, joinCond1) + .build(); + + // swap left and right + RexNode joinCond2 = hyperGraph.extractJoinCond(child2, child1, edges); + RelNode newPlan2 = builder + .push(child2) + .push(child1) + .join(joinType, joinCond2) + .build(); + RelNode winPlan = mq.getCumulativeCost(newPlan1).isLt(mq.getCumulativeCost(newPlan2)) + ? newPlan1 + : newPlan2; + + if (!dpTable.containsKey(csg | cmp)) { + dpTable.put(csg | cmp, winPlan); + } else { + RelNode oriPlan = dpTable.get(csg | cmp); + if (mq.getCumulativeCost(winPlan).isLt(mq.getCumulativeCost(oriPlan))) { + dpTable.put(csg | cmp, winPlan); + } + } + } + + public @Nullable RelNode getBestPlan() { + int size = hyperGraph.getInputs().size(); + long wholeGraph = LongBitmap.newBitmapBetween(0, size); + return dpTable.get(wholeGraph); + } + +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/DphypJoinReorderRule.java b/core/src/main/java/org/apache/calcite/rel/rules/DphypJoinReorderRule.java new file mode 100644 index 000000000000..ce640bffedee --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/DphypJoinReorderRule.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.List; + +/** Rule that re-orders a {@link Join} tree using dphyp algorithm. + * + * @see CoreRules#HYPER_GRAPH_OPTIMIZE */ +@Value.Enclosing +public class DphypJoinReorderRule + extends RelRule + implements TransformationRule { + + protected DphypJoinReorderRule(Config config) { + super(config); + } + + @Override public void onMatch(RelOptRuleCall call) { + HyperGraph hyperGraph = call.rel(0); + RelBuilder relBuilder = call.builder(); + // make all field name unique and convert the + // HyperEdge condition from RexInputRef to RexInputFieldName + hyperGraph.convertHyperEdgeCond(); + + // enumerate by Dphyp + DpHyp dpHyp = new DpHyp(hyperGraph, relBuilder, call.getMetadataQuery()); + dpHyp.startEnumerateJoin(); + RelNode orderedJoin = dpHyp.getBestPlan(); + if (orderedJoin == null) { + return; + } + + // permute field to origin order + List oriNames = hyperGraph.getRowType().getFieldNames(); + List newNames = orderedJoin.getRowType().getFieldNames(); + List projects = new ArrayList<>(); + RexBuilder rexBuilder = hyperGraph.getCluster().getRexBuilder(); + for (String oriName : oriNames) { + projects.add(rexBuilder.makeInputRef(orderedJoin, newNames.indexOf(oriName))); + } + + RelNode result = call.builder() + .push(orderedJoin) + .project(projects) + .build(); + call.transformTo(result); + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableDphypJoinReorderRule.Config.of() + .withOperandSupplier(b1 -> + b1.operand(HyperGraph.class).anyInputs()); + + @Override default DphypJoinReorderRule toRule() { + return new DphypJoinReorderRule(this); + } + } +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java b/core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java new file mode 100644 index 000000000000..167971ccc238 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/HyperEdge.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; + +/** + * Edge in HyperGraph, that represents a join predicate. + */ +public class HyperEdge { + + private long leftNodeBits; + + private long rightNodeBits; + + private JoinRelType joinType; + + private RexNode condition; + + public HyperEdge(long leftNodeBits, long rightNodeBits, JoinRelType joinType, RexNode condition) { + this.leftNodeBits = leftNodeBits; + this.rightNodeBits = rightNodeBits; + this.joinType = joinType; + this.condition = condition; + } + + public long getNodeBitmap() { + return leftNodeBits | rightNodeBits; + } + + public long getLeftNodeBitmap() { + return leftNodeBits; + } + + public long getRightNodeBitmap() { + return rightNodeBits; + } + + // hyperedge (u, v) is simple if |u| = |v| = 1 + public boolean isSimple() { + boolean leftSimple = (leftNodeBits & (leftNodeBits - 1)) == 0; + boolean rightSimple = (rightNodeBits & (rightNodeBits - 1)) == 0; + return leftSimple && rightSimple; + } + + public JoinRelType getJoinType() { + return joinType; + } + + public RexNode getCondition() { + return condition; + } + + @Override public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(LongBitmap.printBitmap(leftNodeBits)) + .append("——").append(joinType).append("——") + .append(LongBitmap.printBitmap(rightNodeBits)); + return sb.toString(); + } + + // before starting dphyp, replace RexInputRef to RexInputFieldName + public void replaceCondition(RexNode fieldNameCond) { + this.condition = fieldNameCond; + } + +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java b/core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java new file mode 100644 index 000000000000..089553c3c481 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/HyperGraph.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.linq4j.Ord; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBiVisitor; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVariable; +import org.apache.calcite.rex.RexVisitor; + +import com.google.common.collect.ImmutableList; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; + +/** + * HyperGraph represents a join graph. + */ +public class HyperGraph extends AbstractRelNode { + + private List inputs; + + private List edges; + + // key is the bitmap for inputs, value is the hyper edge bitmap in edges + private HashMap simpleEdgesMap; + + private HashMap complexEdgesMap; + + // node bitmap overlaps edge's leftNodeBits or rightNodeBits, but does not completely cover + private HashMap overlapEdgesMap; + + protected HyperGraph(RelOptCluster cluster, + RelTraitSet traitSet, + List inputs, + List edges, + RelDataType rowType) { + super(cluster, traitSet); + checkArgument(rowType != null); + this.inputs = inputs; + this.edges = edges; + this.rowType = rowType; + this.simpleEdgesMap = new HashMap<>(); + this.complexEdgesMap = new HashMap<>(); + this.overlapEdgesMap = new HashMap<>(); + } + + protected HyperGraph(RelOptCluster cluster, + RelTraitSet traitSet, + List inputs, + List edges, + RelDataType rowType, + HashMap simpleEdgesMap, + HashMap complexEdgesMap, + HashMap overlapEdgesMap) { + super(cluster, traitSet); + checkArgument(rowType != null); + this.inputs = inputs; + this.edges = edges; + this.rowType = rowType; + this.simpleEdgesMap = simpleEdgesMap; + this.complexEdgesMap = complexEdgesMap; + this.overlapEdgesMap = overlapEdgesMap; + } + + @Override public RelNode copy(RelTraitSet traitSet, List inputs) { + return new HyperGraph( + getCluster(), + traitSet, + inputs, + edges, + rowType, + simpleEdgesMap, + complexEdgesMap, + overlapEdgesMap); + } + + @Override public RelWriter explainTerms(RelWriter pw) { + super.explainTerms(pw); + for (Ord ord : Ord.zip(inputs)) { + pw.input("input#" + ord.i, ord.e); + } + List hyperEdges = edges.stream() + .map(hyperEdge -> hyperEdge.toString()) + .collect(Collectors.toList()); + pw.item("edges", String.join(",", hyperEdges)); + return pw; + } + + @Override public List getInputs() { + return inputs; + } + + @Override public void replaceInput(int ordinalInParent, RelNode p) { + inputs.set(ordinalInParent, p); + recomputeDigest(); + } + + @Override public RelDataType deriveRowType() { + return rowType; + } + + //~ hyper graph method ---------------------------------------------------------- + + public List getEdges() { + return edges; + } + + public long getNeighborBitmap(long csg, long forbidden) { + long neighbors = 0L; + List simpleEdges = simpleEdgesMap.getOrDefault(csg, new BitSet()).stream() + .mapToObj(edges::get) + .collect(Collectors.toList()); + for (HyperEdge edge : simpleEdges) { + neighbors |= edge.getNodeBitmap(); + } + + forbidden = forbidden | csg; + neighbors = neighbors & ~forbidden; + forbidden = forbidden | neighbors; + + List complexEdges = complexEdgesMap.getOrDefault(csg, new BitSet()).stream() + .mapToObj(edges::get) + .collect(Collectors.toList()); + for (HyperEdge edge : complexEdges) { + long leftBitmap = edge.getLeftNodeBitmap(); + long rightBitmap = edge.getRightNodeBitmap(); + if (LongBitmap.isSubSet(leftBitmap, csg) && !LongBitmap.isOverlap(rightBitmap, forbidden)) { + neighbors |= Long.lowestOneBit(rightBitmap); + } else if (LongBitmap.isSubSet(rightBitmap, csg) + && !LongBitmap.isOverlap(leftBitmap, forbidden)) { + neighbors |= Long.lowestOneBit(leftBitmap); + } + } + return neighbors; + } + + public List connectCsgCmp(long csg, long cmp) { + checkArgument(simpleEdgesMap.containsKey(csg)); + checkArgument(simpleEdgesMap.containsKey(cmp)); + List connectedEdges = new ArrayList<>(); + BitSet connectedEdgesBitmap = new BitSet(); + connectedEdgesBitmap.or(simpleEdgesMap.get(csg)); + connectedEdgesBitmap.or(complexEdgesMap.get(csg)); + connectedEdgesBitmap.or(overlapEdgesMap.get(csg)); + + BitSet cmpEdgesBitmap = new BitSet(); + cmpEdgesBitmap.or(simpleEdgesMap.get(cmp)); + cmpEdgesBitmap.or(complexEdgesMap.get(cmp)); + cmpEdgesBitmap.or(overlapEdgesMap.get(cmp)); + + connectedEdgesBitmap.and(cmpEdgesBitmap); + connectedEdgesBitmap.stream() + .filter(index -> LongBitmap.isSubSet(edges.get(index).getNodeBitmap(), csg | cmp)) + .forEach(index -> connectedEdges.add(edges.get(index))); + + return connectedEdges; + } + + public void initEdgeBitMap(long subset) { + BitSet simpleBitSet = new BitSet(); + BitSet complexBitSet = new BitSet(); + BitSet overlapBitSet = new BitSet(); + for (int i = 0; i < edges.size(); i++) { + HyperEdge edge = edges.get(i); + if (isAccurateEdge(edge, subset)) { + if (edge.isSimple()) { + simpleBitSet.set(i); + } else { + complexBitSet.set(i); + } + } else if (isOverlapEdge(edge, subset)) { + overlapBitSet.set(i); + } + } + simpleEdgesMap.put(subset, simpleBitSet); + complexEdgesMap.put(subset, complexBitSet); + overlapEdgesMap.put(subset, overlapBitSet); + } + + public void updateEdgesForUnion(long subset1, long subset2) { + if (!simpleEdgesMap.containsKey(subset1)) { + initEdgeBitMap(subset1); + } + if (!simpleEdgesMap.containsKey(subset2)) { + initEdgeBitMap(subset2); + } + long unionSet = subset1 | subset2; + if (simpleEdgesMap.containsKey(unionSet)) { + return; + } + + BitSet unionSimpleBitSet = new BitSet(); + unionSimpleBitSet.or(simpleEdgesMap.get(subset1)); + unionSimpleBitSet.or(simpleEdgesMap.get(subset2)); + + BitSet unionComplexBitSet = new BitSet(); + unionComplexBitSet.or(complexEdgesMap.get(subset1)); + unionComplexBitSet.or(complexEdgesMap.get(subset2)); + + BitSet unionOverlapBitSet = new BitSet(); + unionOverlapBitSet.or(overlapEdgesMap.get(subset1)); + unionOverlapBitSet.or(overlapEdgesMap.get(subset2)); + + // the overlaps edge that belongs to subset1/subset2 + // may be complex edge for subset1 union subset2 + for (int index : unionOverlapBitSet.stream().toArray()) { + HyperEdge edge = edges.get(index); + if (isAccurateEdge(edge, unionSet)) { + unionComplexBitSet.set(index); + unionOverlapBitSet.set(index, false); + } + } + + // remove cycle in subset1 union subset2 + for (int index : unionSimpleBitSet.stream().toArray()) { + HyperEdge edge = edges.get(index); + if (!isAccurateEdge(edge, unionSet)) { + unionSimpleBitSet.set(index, false); + } + } + for (int index : unionComplexBitSet.stream().toArray()) { + HyperEdge edge = edges.get(index); + if (!isAccurateEdge(edge, unionSet)) { + unionComplexBitSet.set(index, false); + } + } + + simpleEdgesMap.put(unionSet, unionSimpleBitSet); + complexEdgesMap.put(unionSet, unionComplexBitSet); + overlapEdgesMap.put(unionSet, unionOverlapBitSet); + } + + private static boolean isAccurateEdge(HyperEdge edge, long subset) { + boolean isLeftEnd = LongBitmap.isSubSet(edge.getLeftNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getRightNodeBitmap(), subset); + boolean isRightEnd = LongBitmap.isSubSet(edge.getRightNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getLeftNodeBitmap(), subset); + return isLeftEnd || isRightEnd; + } + + private static boolean isOverlapEdge(HyperEdge edge, long subset) { + boolean isLeftEnd = LongBitmap.isOverlap(edge.getLeftNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getRightNodeBitmap(), subset); + boolean isRightEnd = LongBitmap.isOverlap(edge.getRightNodeBitmap(), subset) + && !LongBitmap.isOverlap(edge.getLeftNodeBitmap(), subset); + return isLeftEnd || isRightEnd; + } + + public @Nullable JoinRelType extractJoinType(List edges) { + JoinRelType joinType = edges.get(0).getJoinType(); + for (int i = 1; i < edges.size(); i++) { + if (edges.get(i).getJoinType() != joinType) { + return null; + } + } + return joinType; + } + + public RexNode extractJoinCond(RelNode left, RelNode right, List edges) { + List joinConds = new ArrayList<>(); + List fieldList = new ArrayList<>(left.getRowType().getFieldList()); + fieldList.addAll(right.getRowType().getFieldList()); + + List names = new ArrayList<>(left.getRowType().getFieldNames()); + names.addAll(right.getRowType().getFieldNames()); + + // convert the HyperEdge's condition from RexInputFieldName to RexInputRef + RexShuttle inputName2InputRefShuttle = new RexShuttle() { + @Override protected List visitList( + List exprs, + boolean @Nullable [] update) { + ImmutableList.Builder clonedOperands = ImmutableList.builder(); + for (RexNode operand : exprs) { + RexNode clonedOperand; + if (operand instanceof RexInputFieldName) { + int index = names.indexOf(((RexInputFieldName) operand).getName()); + clonedOperand = new RexInputRef(index, fieldList.get(index).getType()); + } else { + clonedOperand = operand.accept(this); + } + if ((clonedOperand != operand) && (update != null)) { + update[0] = true; + } + clonedOperands.add(clonedOperand); + } + return clonedOperands.build(); + } + }; + + for (HyperEdge edge : edges) { + RexNode inputRefCond = edge.getCondition().accept(inputName2InputRefShuttle); + joinConds.add(inputRefCond); + } + return RexUtil.composeConjunction(left.getCluster().getRexBuilder(), joinConds); + } + + /** + * Before starting enumeration, add Project on every input, make all field name unique. + * Convert the HyperEdge condition from RexInputRef to RexInputFieldName + */ + public void convertHyperEdgeCond() { + int fieldIndex = 0; + List fieldList = rowType.getFieldList(); + for (int nodeIndex = 0; nodeIndex < inputs.size(); nodeIndex++) { + RelNode input = inputs.get(nodeIndex); + List projects = new ArrayList<>(); + List names = new ArrayList<>(); + for (int i = 0; i < input.getRowType().getFieldCount(); i++) { + projects.add( + new RexInputRef( + i, + fieldList.get(fieldIndex).getType())); + names.add(fieldList.get(fieldIndex).getName()); + fieldIndex++; + } + RelNode renameProject = LogicalProject.create( + input, + ImmutableList.of(), + projects, + names, + input.getVariablesSet()); + replaceInput(nodeIndex, renameProject); + } + + RexShuttle inputRef2inputNameShuttle = new RexShuttle() { + @Override public RexNode visitInputRef(RexInputRef inputRef) { + int index = inputRef.getIndex(); + return new RexInputFieldName( + fieldList.get(index).getName(), + fieldList.get(index).getType()); + } + }; + + for (HyperEdge hyperEdge : edges) { + RexNode convertCond = hyperEdge.getCondition().accept(inputRef2inputNameShuttle); + hyperEdge.replaceCondition(convertCond); + } + } + + /** + * Adjusting RexInputRef in enumeration process is too complicated, + * so use unique name replace input ref. + * Before starting enumeration, convert RexInputRef to RexInputFieldName. + * When connect csgcmp to Join, convert RexInputFieldName to RexInputRef. + */ + private static class RexInputFieldName extends RexVariable { + + RexInputFieldName(final String fieldName, final RelDataType type) { + super(fieldName, type); + } + + @Override public R accept(RexVisitor visitor) { + throw new UnsupportedOperationException(); + } + + @Override public R accept(RexBiVisitor visitor, P arg) { + throw new UnsupportedOperationException(); + } + + @Override public boolean equals(@Nullable Object obj) { + return this == obj + || obj instanceof RexInputFieldName + && name == ((RexInputFieldName) obj).name + && type.equals(((RexInputFieldName) obj).type); + } + + @Override public int hashCode() { + return name.hashCode(); + } + + @Override public String toString() { + return name; + } + } +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/JoinToHyperGraphRule.java b/core/src/main/java/org/apache/calcite/rel/rules/JoinToHyperGraphRule.java new file mode 100644 index 000000000000..aa13b68e8a79 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/JoinToHyperGraphRule.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexVisitorImpl; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; + +/** Rule that flattens a tree of {@link LogicalJoin}s + * into a single {@link HyperGraph} with N inputs. + * + * @see CoreRules#JOIN_TO_HYPER_GRAPH + */ +@Value.Enclosing +public class JoinToHyperGraphRule + extends RelRule + implements TransformationRule { + + protected JoinToHyperGraphRule(Config config) { + super(config); + } + + @Override public void onMatch(RelOptRuleCall call) { + final Join origJoin = call.rel(0); + final RelNode left = call.rel(1); + final RelNode right = call.rel(2); + if (origJoin.getJoinType() != JoinRelType.INNER) { + return; + } + + HyperGraph result; + List inputs = new ArrayList<>(); + List edges = new ArrayList<>(); + List joinConds = new ArrayList<>(); + + if (origJoin.getCondition().isAlwaysTrue()) { + joinConds.add(origJoin.getCondition()); + } else { + RelOptUtil.decomposeConjunction(origJoin.getCondition(), joinConds); + } + + // when right is HyperGraph, need shift the leftNodeBit, rightNodeBit, condition of HyperEdge + int leftNodeCount; + int leftFieldCount = left.getRowType().getFieldCount(); + if (left instanceof HyperGraph && right instanceof HyperGraph) { + leftNodeCount = left.getInputs().size(); + inputs.addAll(left.getInputs()); + inputs.addAll(right.getInputs()); + + edges.addAll(((HyperGraph) left).getEdges()); + edges.addAll( + ((HyperGraph) right).getEdges().stream() + .map(hyperEdge -> adjustNodeBit(hyperEdge, leftNodeCount, leftFieldCount)) + .collect(Collectors.toList())); + } else if (left instanceof HyperGraph) { + leftNodeCount = left.getInputs().size(); + inputs.addAll(left.getInputs()); + inputs.add(right); + + edges.addAll(((HyperGraph) left).getEdges()); + } else if (right instanceof HyperGraph) { + leftNodeCount = 1; + inputs.add(left); + inputs.addAll(right.getInputs()); + + edges.addAll( + ((HyperGraph) right).getEdges().stream() + .map(hyperEdge -> adjustNodeBit(hyperEdge, leftNodeCount, leftFieldCount)) + .collect(Collectors.toList())); + } else { + leftNodeCount = 1; + inputs.add(left); + inputs.add(right); + } + + HashMap fieldIndexToNodeIndexMap = new HashMap<>(); + int fieldCount = 0; + for (int i = 0; i < inputs.size(); i++) { + for (int j = 0; j < inputs.get(i).getRowType().getFieldCount(); j++) { + fieldIndexToNodeIndexMap.put(fieldCount++, i); + } + } + // convert current join condition to hyper edge condition + for (RexNode joinCond : joinConds) { + long leftNodeBits; + long rightNodeBits; + List leftRefs = new ArrayList<>(); + List rightRefs = new ArrayList<>(); + + RexVisitorImpl visitor = new RexVisitorImpl(true) { + @Override public Void visitInputRef(RexInputRef inputRef) { + checkArgument(fieldIndexToNodeIndexMap.containsKey(inputRef.getIndex())); + int nodeIndex = fieldIndexToNodeIndexMap.get(inputRef.getIndex()); + if (nodeIndex < leftNodeCount) { + leftRefs.add(nodeIndex); + } else { + rightRefs.add(nodeIndex); + } + return null; + } + }; + joinCond.accept(visitor); + + // when cartesian product, make it to complex hyper edge + if (leftRefs.isEmpty() || rightRefs.isEmpty()) { + leftNodeBits = LongBitmap.newBitmapBetween(0, leftNodeCount); + rightNodeBits = LongBitmap.newBitmapBetween(leftNodeCount, inputs.size()); + } else { + leftNodeBits = LongBitmap.newBitmapFromList(leftRefs); + rightNodeBits = LongBitmap.newBitmapFromList(rightRefs); + } + edges.add( + new HyperEdge( + leftNodeBits, + rightNodeBits, + origJoin.getJoinType(), + joinCond)); + } + result = + new HyperGraph(origJoin.getCluster(), + origJoin.getTraitSet(), + inputs, + edges, + origJoin.getRowType()); + + call.transformTo(result); + } + + private static HyperEdge adjustNodeBit(HyperEdge hyperEdge, int nodeOffset, int fieldOffset) { + RexShuttle shuttle = new RexShuttle() { + @Override public RexNode visitInputRef(RexInputRef inputRef) { + return new RexInputRef( + inputRef.getIndex() + fieldOffset, + inputRef.getType()); + } + }; + RexNode newCondition = hyperEdge.getCondition().accept(shuttle); + return new HyperEdge( + hyperEdge.getLeftNodeBitmap() << nodeOffset, + hyperEdge.getRightNodeBitmap() << nodeOffset, + hyperEdge.getJoinType(), + newCondition); + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableJoinToHyperGraphRule.Config.of() + .withOperandSupplier(b1 -> + b1.operand(Join.class).inputs( + b2 -> b2.operand(RelNode.class).anyInputs(), + b3 -> b3.operand(RelNode.class).anyInputs())); + + @Override default JoinToHyperGraphRule toRule() { + return new JoinToHyperGraphRule(this); + } + } + +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/LongBitmap.java b/core/src/main/java/org/apache/calcite/rel/rules/LongBitmap.java new file mode 100644 index 000000000000..39ca3ec586c8 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/LongBitmap.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * Bitmap tool for dphyp. + */ +public class LongBitmap { + + private LongBitmap() {} + + public static long newBitmapBetween(int startInclude, int endExclude) { + long bitmap = 0; + for (int i = startInclude; i < endExclude; i++) { + bitmap |= 1L << i; + } + return bitmap; + } + + public static long newBitmap(int value) { + return 1L << value; + } + + /** + * Corresponding to Bv = {node|node ≺ csg} in "Dynamic programming strikes back". + */ + public static long getBvBitmap(long csg) { + return (csg & -csg) - 1; + } + + public static boolean isSubSet(long maySub, long bigger) { + return (bigger | maySub) == bigger; + } + + public static boolean isOverlap(long bitmap1, long bitmap2) { + return (bitmap1 & bitmap2) != 0; + } + + public static long newBitmapFromList(List values) { + long bitmap = 0; + for (int value : values) { + bitmap |= 1L << value; + } + return bitmap; + } + + public static String printBitmap(long bitmap) { + StringBuilder sb = new StringBuilder(); + sb.append("{"); + while (bitmap != 0) { + sb.append(Long.numberOfTrailingZeros(bitmap)).append(", "); + bitmap = bitmap & (bitmap - 1); + } + sb.delete(sb.length() - 2, sb.length()); + sb.append("}"); + return sb.toString(); + } + + /** + * Traverse the bitmap in reverse order. + */ + public static class ReverseIterator implements Iterable { + + private long bitmap; + + public ReverseIterator(long bitmap) { + this.bitmap = bitmap; + } + + @Override public Iterator iterator() { + return new Iterator() { + @Override public boolean hasNext() { + return bitmap != 0; + } + + @Override public Long next() { + long res = Long.highestOneBit(bitmap); + bitmap &= ~res; + return res; + } + }; + } + } + + /** + * Enumerate all subsets of a bitmap from small to large. + */ + public static class SubsetIterator implements Iterable { + + private ArrayList subsetList; + + private int index; + + public SubsetIterator(long bitmap) { + long curBiggestSubset = bitmap; + this.subsetList = new ArrayList<>(); + + while (curBiggestSubset != 0) { + subsetList.add(curBiggestSubset); + curBiggestSubset = (curBiggestSubset - 1) & bitmap; + } + + this.index = subsetList.size() - 1; + } + + @Override public Iterator iterator() { + return new Iterator() { + @Override public boolean hasNext() { + return index >= 0; + } + + @Override public Long next() { + return subsetList.get(index--); + } + }; + } + + public void reset() { + index = subsetList.size() - 1; + } + } + +} diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index b02a929e675e..43f6cf2438b2 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -9546,4 +9546,21 @@ private void checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway .withRule(CoreRules.MULTI_JOIN_OPTIMIZE) .check(); } + + @Test void testDphypJoinReorder() { + HepProgram program = new HepProgramBuilder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.FILTER_INTO_JOIN) + .addRuleInstance(CoreRules.JOIN_TO_HYPER_GRAPH) + .build(); + + sql("select emp.empno from " + + "emp, dept, emp_address, dept_nested " + + "where emp.deptno = dept.deptno " + + "and emp.empno = emp_address.empno " + + "and dept.deptno = dept_nested.deptno") + .withPre(program) + .withRule(CoreRules.HYPER_GRAPH_OPTIMIZE, CoreRules.PROJECT_REMOVE, CoreRules.PROJECT_MERGE) + .check(); + } } diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 60c683ee50e6..1399e8436f6b 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -3189,6 +3189,33 @@ LogicalProject(EXPR$0=[/(1.0:DECIMAL(2, 1), 0.0E0:DOUBLE)]) + + + + + + + + + + +