From 640c191e8e57b4753119d26b0aa5a0c46ef3cb52 Mon Sep 17 00:00:00 2001 From: Ruben Quesada Lopez Date: Fri, 7 Mar 2025 20:09:06 +0000 Subject: [PATCH] [CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner --- .../rel/rules/LoptOptimizeJoinRule.java | 63 ++++++++++-------- .../apache/calcite/test/RelOptRulesTest.java | 66 +++++++++++++++++++ .../apache/calcite/test/RelOptRulesTest.xml | 56 ++++++++++++++++ 3 files changed, 157 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java index 90a0c9be82b..16cb367ac2f 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java @@ -132,7 +132,7 @@ public LoptOptimizeJoinRule(RelFactories.JoinFactory joinFactory, findRemovableSelfJoins(mq, multiJoin); - findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call); + findBestOrderings(call, multiJoin, semiJoinOpt); } /** @@ -442,12 +442,10 @@ private static boolean isSelfJoinFilterUnique( * @param semiJoinOpt optimal semijoins for each factor * @param call RelOptRuleCall associated with this rule */ - private static void findBestOrderings( - RelMetadataQuery mq, - RelBuilder relBuilder, + private void findBestOrderings( + RelOptRuleCall call, LoptMultiJoin multiJoin, - LoptSemiJoinOptimizer semiJoinOpt, - RelOptRuleCall call) { + LoptSemiJoinOptimizer semiJoinOpt) { final List plans = new ArrayList<>(); final List fieldNames = @@ -461,8 +459,7 @@ private static void findBestOrderings( } LoptJoinTree joinTree = createOrdering( - mq, - relBuilder, + call, multiJoin, semiJoinOpt, i); @@ -679,9 +676,8 @@ private static void setFactorJoinKeys( * @return constructed join tree or null if it is not possible for * firstFactor to appear as the first factor in the join */ - private static @Nullable 
LoptJoinTree createOrdering( - RelMetadataQuery mq, - RelBuilder relBuilder, + private @Nullable LoptJoinTree createOrdering( + RelOptRuleCall call, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, int firstFactor) { @@ -712,7 +708,7 @@ private static void setFactorJoinKeys( } else { nextFactor = getBestNextFactor( - mq, + call.getMetadataQuery(), multiJoin, factorsToAdd, factorsAdded, @@ -733,8 +729,7 @@ private static void setFactorJoinKeys( factorsNeeded.and(factorsAdded); joinTree = addFactorToTree( - mq, - relBuilder, + call, multiJoin, semiJoinOpt, joinTree, @@ -878,9 +873,8 @@ private static boolean isJoinTree(RelNode rel) { * @return optimal join tree with the new factor added if it is possible to * add the factor; otherwise, null is returned */ - private static @Nullable LoptJoinTree addFactorToTree( - RelMetadataQuery mq, - RelBuilder relBuilder, + private @Nullable LoptJoinTree addFactorToTree( + RelOptRuleCall call, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, @Nullable LoptJoinTree joinTree, @@ -888,6 +882,8 @@ private static boolean isJoinTree(RelNode rel) { BitSet factorsNeeded, List filtersToAdd, boolean selfJoin) { + final RelMetadataQuery mq = call.getMetadataQuery(); + final RelBuilder relBuilder = call.builder(); // if the factor corresponds to the null generating factor in an outer // join that can be removed, then create a replacement join @@ -943,8 +939,7 @@ private static boolean isJoinTree(RelNode rel) { selfJoin); LoptJoinTree pushDownTree = pushDownFactor( - mq, - relBuilder, + call, multiJoin, semiJoinOpt, joinTree, @@ -959,10 +954,10 @@ private static boolean isJoinTree(RelNode rel) { RelOptCost costPushDown = null; RelOptCost costTop = null; if (pushDownTree != null) { - costPushDown = mq.getCumulativeCost(pushDownTree.getJoinTree()); + costPushDown = config.costFunction().getCost(call, pushDownTree.getJoinTree()); } if (topTree != null) { - costTop = mq.getCumulativeCost(topTree.getJoinTree()); + costTop = 
config.costFunction().getCost(call, topTree.getJoinTree()); } if (pushDownTree == null) { @@ -1035,9 +1030,8 @@ private static int rowWidthCost(RelNode tree) { * join tree if it is possible to do the pushdown; otherwise, null is * returned */ - private static @Nullable LoptJoinTree pushDownFactor( - RelMetadataQuery mq, - RelBuilder relBuilder, + private @Nullable LoptJoinTree pushDownFactor( + RelOptRuleCall call, LoptMultiJoin multiJoin, LoptSemiJoinOptimizer semiJoinOpt, LoptJoinTree joinTree, @@ -1110,8 +1104,7 @@ private static int rowWidthCost(RelNode tree) { LoptJoinTree subTree = (childNo == 0) ? left : right; subTree = addFactorToTree( - mq, - relBuilder, + call, multiJoin, semiJoinOpt, subTree, @@ -1165,8 +1158,8 @@ private static int rowWidthCost(RelNode tree) { // create the new join tree with the factor pushed down return createJoinSubtree( - mq, - relBuilder, + call.getMetadataQuery(), + call.builder(), multiJoin, left, right, @@ -2089,12 +2082,26 @@ private static boolean areSelfJoinKeysUnique(RelMetadataQuery mq, joinInfo.leftSet()); } + /** Function to compute cost. */ + @FunctionalInterface + public interface CostFunction { + @Nullable RelOptCost getCost(RelOptRuleCall call, RelNode relNode); + } + /** Rule configuration. */ @Value.Immutable public interface Config extends RelRule.Config { Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of() .withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs()); + /** Function used to compute the cost of the intermediate join trees compared by this rule. */ + @Value.Default default CostFunction costFunction() { + return (call, rel) -> call.getMetadataQuery().getCumulativeCost(rel); + } + + /** Sets {@link #costFunction()}. 
*/ + Config withCostFunction(CostFunction function); + @Override default LoptOptimizeJoinRule toRule() { return new LoptOptimizeJoinRule(this); } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 42d5f6149f5..b4f0e06a763 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -26,6 +26,7 @@ import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.plan.Contexts; import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptCost; import org.apache.calcite.plan.RelOptCostImpl; import org.apache.calcite.plan.RelOptPlanner; import org.apache.calcite.plan.RelOptRule; @@ -53,6 +54,7 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Minus; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.core.Union; import org.apache.calcite.rel.hint.HintPredicates; import org.apache.calcite.rel.hint.HintStrategyTable; @@ -62,6 +64,7 @@ import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.logical.LogicalTableModify; +import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule; import org.apache.calcite.rel.rules.AggregateExtractProjectRule; import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule; @@ -78,6 +81,7 @@ import org.apache.calcite.rel.rules.FilterProjectTransposeRule; import org.apache.calcite.rel.rules.JoinAssociateRule; import org.apache.calcite.rel.rules.JoinCommuteRule; +import org.apache.calcite.rel.rules.LoptOptimizeJoinRule; import org.apache.calcite.rel.rules.MeasureRules; import org.apache.calcite.rel.rules.MultiJoin; import 
org.apache.calcite.rel.rules.ProjectCorrelateTransposeRule; @@ -9733,6 +9737,68 @@ private void checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway .check(); } + /** Test case for + * [CALCITE-6788] + * LoptOptimizeJoinRule should be able to delegate costs to the planner. */ + @Test void testLoptOptimizeJoinRuleWithDefaultCost() { + // Use the default rule + checkLoptOptimizeJoinRule(CoreRules.MULTI_JOIN_OPTIMIZE); + } + + /** Test case for + * [CALCITE-6788] + * LoptOptimizeJoinRule should be able to delegate costs to the planner. */ + @Test void testLoptOptimizeJoinRuleWithSpecialCost() { + // Use an ad-hoc version of the rule that uses planner#getCost instead of mq#getCumulativeCost + checkLoptOptimizeJoinRule(LoptOptimizeJoinRule.Config.DEFAULT + .withCostFunction((c, r) -> c.getPlanner().getCost(r, c.getMetadataQuery())) + .toRule()); + } + + private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) { + final HepProgram preProgram = new HepProgramBuilder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN) + .build(); + + final HepProgram program = HepProgram.builder() + .addMatchOrder(HepMatchOrder.BOTTOM_UP) + .addRuleInstance(rule) + .build(); + + // Special planner that artificially favors joins on the same table + final HepPlanner planner = new HepPlanner(program) { + @Override public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) { + if (rel instanceof Join + && rel.getInput(0).stripped() instanceof TableScan + && rel.getInput(1).stripped() instanceof TableScan) { + TableScan left = (TableScan) rel.getInput(0).stripped(); + TableScan right = (TableScan) rel.getInput(1).stripped(); + if (left.getTable().equals(right.getTable())) { + // Tiny cost for self-joins + return getCostFactory().makeTinyCost(); + } + } + + // General case: just define a kind of cumulative cost based on the rowCount (to avoid + // the infinite costs from the Logical operators) + RelOptCost cost = new 
RelOptCostImpl(mq.getRowCount(rel)); + for (RelNode input : rel.getInputs()) { + cost = cost.plus(getCost(input, mq)); + } + return cost; + } + }; + + sql("select e1.empno from emp e1" + + " inner join dept d1 on d1.deptno = e1.deptno" + + " inner join emp e2 on e1.ename = e2.ename" + + " inner join dept d2 on d2.deptno = e1.deptno") + .withPre(preProgram) + .withPlanner(planner) + .check(); + } + /** * Test case of * [CALCITE-6850] diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index ca64c718e6b..543a117e2c4 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -7419,6 +7419,62 @@ LogicalProject(EMPNO=[$0]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) +]]> + + + + + + + + + + + + + + + + + + + + + +