Skip to content

Commit

Permalink
[CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs …
Browse files Browse the repository at this point in the history
…to the planner
  • Loading branch information
rubenada committed Mar 8, 2025
1 parent 802fce3 commit 640c191
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public LoptOptimizeJoinRule(RelFactories.JoinFactory joinFactory,

findRemovableSelfJoins(mq, multiJoin);

findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call);
findBestOrderings(call, multiJoin, semiJoinOpt);
}

/**
Expand Down Expand Up @@ -442,12 +442,10 @@ private static boolean isSelfJoinFilterUnique(
* @param semiJoinOpt optimal semijoins for each factor
* @param call RelOptRuleCall associated with this rule
*/
private static void findBestOrderings(
RelMetadataQuery mq,
RelBuilder relBuilder,
private void findBestOrderings(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
RelOptRuleCall call) {
LoptSemiJoinOptimizer semiJoinOpt) {
final List<RelNode> plans = new ArrayList<>();

final List<String> fieldNames =
Expand All @@ -461,8 +459,7 @@ private static void findBestOrderings(
}
LoptJoinTree joinTree =
createOrdering(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
i);
Expand Down Expand Up @@ -679,9 +676,8 @@ private static void setFactorJoinKeys(
* @return constructed join tree or null if it is not possible for
* firstFactor to appear as the first factor in the join
*/
private static @Nullable LoptJoinTree createOrdering(
RelMetadataQuery mq,
RelBuilder relBuilder,
private @Nullable LoptJoinTree createOrdering(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
int firstFactor) {
Expand Down Expand Up @@ -712,7 +708,7 @@ private static void setFactorJoinKeys(
} else {
nextFactor =
getBestNextFactor(
mq,
call.getMetadataQuery(),
multiJoin,
factorsToAdd,
factorsAdded,
Expand All @@ -733,8 +729,7 @@ private static void setFactorJoinKeys(
factorsNeeded.and(factorsAdded);
joinTree =
addFactorToTree(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
joinTree,
Expand Down Expand Up @@ -878,16 +873,17 @@ private static boolean isJoinTree(RelNode rel) {
* @return optimal join tree with the new factor added if it is possible to
* add the factor; otherwise, null is returned
*/
private static @Nullable LoptJoinTree addFactorToTree(
RelMetadataQuery mq,
RelBuilder relBuilder,
private @Nullable LoptJoinTree addFactorToTree(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
@Nullable LoptJoinTree joinTree,
int factorToAdd,
BitSet factorsNeeded,
List<RexNode> filtersToAdd,
boolean selfJoin) {
final RelMetadataQuery mq = call.getMetadataQuery();
final RelBuilder relBuilder = call.builder();

// if the factor corresponds to the null generating factor in an outer
// join that can be removed, then create a replacement join
Expand Down Expand Up @@ -943,8 +939,7 @@ private static boolean isJoinTree(RelNode rel) {
selfJoin);
LoptJoinTree pushDownTree =
pushDownFactor(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
joinTree,
Expand All @@ -959,10 +954,10 @@ private static boolean isJoinTree(RelNode rel) {
RelOptCost costPushDown = null;
RelOptCost costTop = null;
if (pushDownTree != null) {
costPushDown = mq.getCumulativeCost(pushDownTree.getJoinTree());
costPushDown = config.costFunction().getCost(call, pushDownTree.getJoinTree());
}
if (topTree != null) {
costTop = mq.getCumulativeCost(topTree.getJoinTree());
costTop = config.costFunction().getCost(call, topTree.getJoinTree());
}

if (pushDownTree == null) {
Expand Down Expand Up @@ -1035,9 +1030,8 @@ private static int rowWidthCost(RelNode tree) {
* join tree if it is possible to do the pushdown; otherwise, null is
* returned
*/
private static @Nullable LoptJoinTree pushDownFactor(
RelMetadataQuery mq,
RelBuilder relBuilder,
private @Nullable LoptJoinTree pushDownFactor(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
LoptJoinTree joinTree,
Expand Down Expand Up @@ -1110,8 +1104,7 @@ private static int rowWidthCost(RelNode tree) {
LoptJoinTree subTree = (childNo == 0) ? left : right;
subTree =
addFactorToTree(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
subTree,
Expand Down Expand Up @@ -1165,8 +1158,8 @@ private static int rowWidthCost(RelNode tree) {

// create the new join tree with the factor pushed down
return createJoinSubtree(
mq,
relBuilder,
call.getMetadataQuery(),
call.builder(),
multiJoin,
left,
right,
Expand Down Expand Up @@ -2089,12 +2082,26 @@ private static boolean areSelfJoinKeysUnique(RelMetadataQuery mq,
joinInfo.leftSet());
}

/** Function to compute cost. */
@FunctionalInterface
public interface CostFunction {
@Nullable RelOptCost getCost(RelOptRuleCall call, RelNode relNode);
}

/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of()
.withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs());

/** Function to calculate intermediate cost computations. */
@Value.Default default CostFunction costFunction() {
return (call, rel) -> call.getMetadataQuery().getCumulativeCost(rel);
}

/** Sets {@link #costFunction()}. */
Config withCostFunction(CostFunction function);

@Override default LoptOptimizeJoinRule toRule() {
return new LoptOptimizeJoinRule(this);
}
Expand Down
66 changes: 66 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptCostImpl;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptRule;
Expand Down Expand Up @@ -53,6 +54,7 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.hint.HintPredicates;
import org.apache.calcite.rel.hint.HintStrategyTable;
Expand All @@ -62,6 +64,7 @@
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
Expand All @@ -78,6 +81,7 @@
import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
import org.apache.calcite.rel.rules.JoinAssociateRule;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.rules.LoptOptimizeJoinRule;
import org.apache.calcite.rel.rules.MeasureRules;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.ProjectCorrelateTransposeRule;
Expand Down Expand Up @@ -9733,6 +9737,68 @@ private void checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
* LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
@Test void testLoptOptimizeJoinRuleWithDefaultCost() {
// Use the default rule
checkLoptOptimizeJoinRule(CoreRules.MULTI_JOIN_OPTIMIZE);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
* LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
@Test void testLoptOptimizeJoinRuleWithSpecialCost() {
// Use an ad-hoc version of the rule that uses planner#getCost instead of mq#getCumulativeCost
checkLoptOptimizeJoinRule(LoptOptimizeJoinRule.Config.DEFAULT
.withCostFunction((c, r) -> c.getPlanner().getCost(r, c.getMetadataQuery()))
.toRule());
}

private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
final HepProgram preProgram = new HepProgramBuilder()
.addMatchOrder(HepMatchOrder.BOTTOM_UP)
.addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN)
.build();

final HepProgram program = HepProgram.builder()
.addMatchOrder(HepMatchOrder.BOTTOM_UP)
.addRuleInstance(rule)
.build();

// Special planner that artificially favors joins on the same table
final HepPlanner planner = new HepPlanner(program) {
@Override public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) {
if (rel instanceof Join
&& rel.getInput(0).stripped() instanceof TableScan
&& rel.getInput(1).stripped() instanceof TableScan) {
TableScan left = (TableScan) rel.getInput(0).stripped();
TableScan right = (TableScan) rel.getInput(1).stripped();
if (left.getTable().equals(right.getTable())) {
// Tiny cost for self-joins
return getCostFactory().makeTinyCost();
}
}

// General case: just define a kind of cumulative cost based on the rowCount (to avoid
// the infinite costs from the Logical operators)
RelOptCost cost = new RelOptCostImpl(mq.getRowCount(rel));
for (RelNode input : rel.getInputs()) {
cost = cost.plus(getCost(input, mq));
}
return cost;
}
};

sql("select e1.empno from emp e1"
+ " inner join dept d1 on d1.deptno = e1.deptno"
+ " inner join emp e2 on e1.ename = e2.ename"
+ " inner join dept d2 on d2.deptno = e1.deptno")
.withPre(preProgram)
.withPlanner(planner)
.check();
}

/**
* Test case of
* <a href="https://issues.apache.org/jira/browse/CALCITE-6850">[CALCITE-6850]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7419,6 +7419,62 @@ LogicalProject(EMPNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testLoptOptimizeJoinRuleWithDefaultCost">
<Resource name="sql">
<![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalProject(EMPNO=[$9], ENAME=[$10], JOB=[$11], MGR=[$12], HIREDATE=[$13], SAL=[$14], COMM=[$15], DEPTNO=[$16], SLACKER=[$17], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$0], ENAME0=[$1], JOB0=[$2], MGR0=[$3], HIREDATE0=[$4], SAL0=[$5], COMM0=[$6], DEPTNO1=[$7], SLACKER0=[$8], DEPTNO2=[$20], NAME0=[$21])
LogicalJoin(condition=[=($10, $1)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalJoin(condition=[=($11, $7)], joinType=[inner])
LogicalJoin(condition=[=($9, $7)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testLoptOptimizeJoinRuleWithSpecialCost">
<Resource name="sql">
<![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14], COMM0=[$15], DEPTNO1=[$16], SLACKER0=[$17], DEPTNO2=[$20], NAME0=[$21])
LogicalJoin(condition=[=($20, $7)], joinType=[inner])
LogicalJoin(condition=[=($18, $7)], joinType=[inner])
LogicalJoin(condition=[=($1, $10)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
Expand Down

0 comments on commit 640c191

Please sign in to comment.