Skip to content

Commit 640c191

Browse files
committed
[CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner
1 parent 802fce3 commit 640c191

File tree

3 files changed

+157
-28
lines changed

3 files changed

+157
-28
lines changed

core/src/main/java/org/apache/calcite/rel/rules/LoptOptimizeJoinRule.java

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public LoptOptimizeJoinRule(RelFactories.JoinFactory joinFactory,
132132

133133
findRemovableSelfJoins(mq, multiJoin);
134134

135-
findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call);
135+
findBestOrderings(call, multiJoin, semiJoinOpt);
136136
}
137137

138138
/**
@@ -442,12 +442,10 @@ private static boolean isSelfJoinFilterUnique(
442442
* @param semiJoinOpt optimal semijoins for each factor
443443
* @param call RelOptRuleCall associated with this rule
444444
*/
445-
private static void findBestOrderings(
446-
RelMetadataQuery mq,
447-
RelBuilder relBuilder,
445+
private void findBestOrderings(
446+
RelOptRuleCall call,
448447
LoptMultiJoin multiJoin,
449-
LoptSemiJoinOptimizer semiJoinOpt,
450-
RelOptRuleCall call) {
448+
LoptSemiJoinOptimizer semiJoinOpt) {
451449
final List<RelNode> plans = new ArrayList<>();
452450

453451
final List<String> fieldNames =
@@ -461,8 +459,7 @@ private static void findBestOrderings(
461459
}
462460
LoptJoinTree joinTree =
463461
createOrdering(
464-
mq,
465-
relBuilder,
462+
call,
466463
multiJoin,
467464
semiJoinOpt,
468465
i);
@@ -679,9 +676,8 @@ private static void setFactorJoinKeys(
679676
* @return constructed join tree or null if it is not possible for
680677
* firstFactor to appear as the first factor in the join
681678
*/
682-
private static @Nullable LoptJoinTree createOrdering(
683-
RelMetadataQuery mq,
684-
RelBuilder relBuilder,
679+
private @Nullable LoptJoinTree createOrdering(
680+
RelOptRuleCall call,
685681
LoptMultiJoin multiJoin,
686682
LoptSemiJoinOptimizer semiJoinOpt,
687683
int firstFactor) {
@@ -712,7 +708,7 @@ private static void setFactorJoinKeys(
712708
} else {
713709
nextFactor =
714710
getBestNextFactor(
715-
mq,
711+
call.getMetadataQuery(),
716712
multiJoin,
717713
factorsToAdd,
718714
factorsAdded,
@@ -733,8 +729,7 @@ private static void setFactorJoinKeys(
733729
factorsNeeded.and(factorsAdded);
734730
joinTree =
735731
addFactorToTree(
736-
mq,
737-
relBuilder,
732+
call,
738733
multiJoin,
739734
semiJoinOpt,
740735
joinTree,
@@ -878,16 +873,17 @@ private static boolean isJoinTree(RelNode rel) {
878873
* @return optimal join tree with the new factor added if it is possible to
879874
* add the factor; otherwise, null is returned
880875
*/
881-
private static @Nullable LoptJoinTree addFactorToTree(
882-
RelMetadataQuery mq,
883-
RelBuilder relBuilder,
876+
private @Nullable LoptJoinTree addFactorToTree(
877+
RelOptRuleCall call,
884878
LoptMultiJoin multiJoin,
885879
LoptSemiJoinOptimizer semiJoinOpt,
886880
@Nullable LoptJoinTree joinTree,
887881
int factorToAdd,
888882
BitSet factorsNeeded,
889883
List<RexNode> filtersToAdd,
890884
boolean selfJoin) {
885+
final RelMetadataQuery mq = call.getMetadataQuery();
886+
final RelBuilder relBuilder = call.builder();
891887

892888
// if the factor corresponds to the null generating factor in an outer
893889
// join that can be removed, then create a replacement join
@@ -943,8 +939,7 @@ private static boolean isJoinTree(RelNode rel) {
943939
selfJoin);
944940
LoptJoinTree pushDownTree =
945941
pushDownFactor(
946-
mq,
947-
relBuilder,
942+
call,
948943
multiJoin,
949944
semiJoinOpt,
950945
joinTree,
@@ -959,10 +954,10 @@ private static boolean isJoinTree(RelNode rel) {
959954
RelOptCost costPushDown = null;
960955
RelOptCost costTop = null;
961956
if (pushDownTree != null) {
962-
costPushDown = mq.getCumulativeCost(pushDownTree.getJoinTree());
957+
costPushDown = config.costFunction().getCost(call, pushDownTree.getJoinTree());
963958
}
964959
if (topTree != null) {
965-
costTop = mq.getCumulativeCost(topTree.getJoinTree());
960+
costTop = config.costFunction().getCost(call, topTree.getJoinTree());
966961
}
967962

968963
if (pushDownTree == null) {
@@ -1035,9 +1030,8 @@ private static int rowWidthCost(RelNode tree) {
10351030
* join tree if it is possible to do the pushdown; otherwise, null is
10361031
* returned
10371032
*/
1038-
private static @Nullable LoptJoinTree pushDownFactor(
1039-
RelMetadataQuery mq,
1040-
RelBuilder relBuilder,
1033+
private @Nullable LoptJoinTree pushDownFactor(
1034+
RelOptRuleCall call,
10411035
LoptMultiJoin multiJoin,
10421036
LoptSemiJoinOptimizer semiJoinOpt,
10431037
LoptJoinTree joinTree,
@@ -1110,8 +1104,7 @@ private static int rowWidthCost(RelNode tree) {
11101104
LoptJoinTree subTree = (childNo == 0) ? left : right;
11111105
subTree =
11121106
addFactorToTree(
1113-
mq,
1114-
relBuilder,
1107+
call,
11151108
multiJoin,
11161109
semiJoinOpt,
11171110
subTree,
@@ -1165,8 +1158,8 @@ private static int rowWidthCost(RelNode tree) {
11651158

11661159
// create the new join tree with the factor pushed down
11671160
return createJoinSubtree(
1168-
mq,
1169-
relBuilder,
1161+
call.getMetadataQuery(),
1162+
call.builder(),
11701163
multiJoin,
11711164
left,
11721165
right,
@@ -2089,12 +2082,26 @@ private static boolean areSelfJoinKeysUnique(RelMetadataQuery mq,
20892082
joinInfo.leftSet());
20902083
}
20912084

2085+
/** Function to compute cost. */
2086+
@FunctionalInterface
2087+
public interface CostFunction {
2088+
@Nullable RelOptCost getCost(RelOptRuleCall call, RelNode relNode);
2089+
}
2090+
20922091
/** Rule configuration. */
20932092
@Value.Immutable
20942093
public interface Config extends RelRule.Config {
20952094
Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of()
20962095
.withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs());
20972096

2097+
/** Function to calculate intermediate cost computations. */
2098+
@Value.Default default CostFunction costFunction() {
2099+
return (call, rel) -> call.getMetadataQuery().getCumulativeCost(rel);
2100+
}
2101+
2102+
/** Sets {@link #costFunction()}. */
2103+
Config withCostFunction(CostFunction function);
2104+
20982105
@Override default LoptOptimizeJoinRule toRule() {
20992106
return new LoptOptimizeJoinRule(this);
21002107
}

core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.apache.calcite.config.CalciteConnectionConfig;
2727
import org.apache.calcite.plan.Contexts;
2828
import org.apache.calcite.plan.RelOptCluster;
29+
import org.apache.calcite.plan.RelOptCost;
2930
import org.apache.calcite.plan.RelOptCostImpl;
3031
import org.apache.calcite.plan.RelOptPlanner;
3132
import org.apache.calcite.plan.RelOptRule;
@@ -53,6 +54,7 @@
5354
import org.apache.calcite.rel.core.JoinRelType;
5455
import org.apache.calcite.rel.core.Minus;
5556
import org.apache.calcite.rel.core.Project;
57+
import org.apache.calcite.rel.core.TableScan;
5658
import org.apache.calcite.rel.core.Union;
5759
import org.apache.calcite.rel.hint.HintPredicates;
5860
import org.apache.calcite.rel.hint.HintStrategyTable;
@@ -62,6 +64,7 @@
6264
import org.apache.calcite.rel.logical.LogicalFilter;
6365
import org.apache.calcite.rel.logical.LogicalProject;
6466
import org.apache.calcite.rel.logical.LogicalTableModify;
67+
import org.apache.calcite.rel.metadata.RelMetadataQuery;
6568
import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule;
6669
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
6770
import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
@@ -78,6 +81,7 @@
7881
import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
7982
import org.apache.calcite.rel.rules.JoinAssociateRule;
8083
import org.apache.calcite.rel.rules.JoinCommuteRule;
84+
import org.apache.calcite.rel.rules.LoptOptimizeJoinRule;
8185
import org.apache.calcite.rel.rules.MeasureRules;
8286
import org.apache.calcite.rel.rules.MultiJoin;
8387
import org.apache.calcite.rel.rules.ProjectCorrelateTransposeRule;
@@ -9733,6 +9737,68 @@ private void checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway
97339737
.check();
97349738
}
97359739

9740+
/** Test case for
9741+
* <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
9742+
* LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
9743+
@Test void testLoptOptimizeJoinRuleWithDefaultCost() {
9744+
// Use the default rule
9745+
checkLoptOptimizeJoinRule(CoreRules.MULTI_JOIN_OPTIMIZE);
9746+
}
9747+
9748+
/** Test case for
9749+
* <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
9750+
* LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
9751+
@Test void testLoptOptimizeJoinRuleWithSpecialCost() {
9752+
// Use an ad-hoc version of the rule that uses planner#getCost instead of mq#getCumulativeCost
9753+
checkLoptOptimizeJoinRule(LoptOptimizeJoinRule.Config.DEFAULT
9754+
.withCostFunction((c, r) -> c.getPlanner().getCost(r, c.getMetadataQuery()))
9755+
.toRule());
9756+
}
9757+
9758+
private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
9759+
final HepProgram preProgram = new HepProgramBuilder()
9760+
.addMatchOrder(HepMatchOrder.BOTTOM_UP)
9761+
.addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN)
9762+
.build();
9763+
9764+
final HepProgram program = HepProgram.builder()
9765+
.addMatchOrder(HepMatchOrder.BOTTOM_UP)
9766+
.addRuleInstance(rule)
9767+
.build();
9768+
9769+
// Special planner that artificially favors joins on the same table
9770+
final HepPlanner planner = new HepPlanner(program) {
9771+
@Override public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) {
9772+
if (rel instanceof Join
9773+
&& rel.getInput(0).stripped() instanceof TableScan
9774+
&& rel.getInput(1).stripped() instanceof TableScan) {
9775+
TableScan left = (TableScan) rel.getInput(0).stripped();
9776+
TableScan right = (TableScan) rel.getInput(1).stripped();
9777+
if (left.getTable().equals(right.getTable())) {
9778+
// Tiny cost for self-joins
9779+
return getCostFactory().makeTinyCost();
9780+
}
9781+
}
9782+
9783+
// General case: just define a kind of cumulative cost based on the rowCount (to avoid
9784+
// the infinite costs from the Logical operators)
9785+
RelOptCost cost = new RelOptCostImpl(mq.getRowCount(rel));
9786+
for (RelNode input : rel.getInputs()) {
9787+
cost = cost.plus(getCost(input, mq));
9788+
}
9789+
return cost;
9790+
}
9791+
};
9792+
9793+
sql("select e1.empno from emp e1"
9794+
+ " inner join dept d1 on d1.deptno = e1.deptno"
9795+
+ " inner join emp e2 on e1.ename = e2.ename"
9796+
+ " inner join dept d2 on d2.deptno = e1.deptno")
9797+
.withPre(preProgram)
9798+
.withPlanner(planner)
9799+
.check();
9800+
}
9801+
97369802
/**
97379803
* Test case of
97389804
* <a href="https://issues.apache.org/jira/browse/CALCITE-6850">[CALCITE-6850]

core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7419,6 +7419,62 @@ LogicalProject(EMPNO=[$0])
74197419
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
74207420
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
74217421
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7422+
]]>
7423+
</Resource>
7424+
</TestCase>
7425+
<TestCase name="testLoptOptimizeJoinRuleWithDefaultCost">
7426+
<Resource name="sql">
7427+
<![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
7428+
</Resource>
7429+
<Resource name="planBefore">
7430+
<![CDATA[
7431+
LogicalProject(EMPNO=[$0])
7432+
MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
7433+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7434+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7435+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7436+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7437+
]]>
7438+
</Resource>
7439+
<Resource name="planAfter">
7440+
<![CDATA[
7441+
LogicalProject(EMPNO=[$0])
7442+
LogicalProject(EMPNO=[$9], ENAME=[$10], JOB=[$11], MGR=[$12], HIREDATE=[$13], SAL=[$14], COMM=[$15], DEPTNO=[$16], SLACKER=[$17], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$0], ENAME0=[$1], JOB0=[$2], MGR0=[$3], HIREDATE0=[$4], SAL0=[$5], COMM0=[$6], DEPTNO1=[$7], SLACKER0=[$8], DEPTNO2=[$20], NAME0=[$21])
7443+
LogicalJoin(condition=[=($10, $1)], joinType=[inner])
7444+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7445+
LogicalJoin(condition=[=($11, $7)], joinType=[inner])
7446+
LogicalJoin(condition=[=($9, $7)], joinType=[inner])
7447+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7448+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7449+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7450+
]]>
7451+
</Resource>
7452+
</TestCase>
7453+
<TestCase name="testLoptOptimizeJoinRuleWithSpecialCost">
7454+
<Resource name="sql">
7455+
<![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
7456+
</Resource>
7457+
<Resource name="planBefore">
7458+
<![CDATA[
7459+
LogicalProject(EMPNO=[$0])
7460+
MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
7461+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7462+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7463+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7464+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7465+
]]>
7466+
</Resource>
7467+
<Resource name="planAfter">
7468+
<![CDATA[
7469+
LogicalProject(EMPNO=[$0])
7470+
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14], COMM0=[$15], DEPTNO1=[$16], SLACKER0=[$17], DEPTNO2=[$20], NAME0=[$21])
7471+
LogicalJoin(condition=[=($20, $7)], joinType=[inner])
7472+
LogicalJoin(condition=[=($18, $7)], joinType=[inner])
7473+
LogicalJoin(condition=[=($1, $10)], joinType=[inner])
7474+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7475+
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
7476+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
7477+
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
74227478
]]>
74237479
</Resource>
74247480
</TestCase>

0 commit comments

Comments
 (0)