Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CALCITE-6788] LoptOptimizeJoinRule should be able to delegate costs to the planner #4231

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public LoptOptimizeJoinRule(RelFactories.JoinFactory joinFactory,

findRemovableSelfJoins(mq, multiJoin);

findBestOrderings(mq, call.builder(), multiJoin, semiJoinOpt, call);
findBestOrderings(call, multiJoin, semiJoinOpt);
}

/**
Expand Down Expand Up @@ -442,12 +442,10 @@ private static boolean isSelfJoinFilterUnique(
* @param semiJoinOpt optimal semijoins for each factor
* @param call RelOptRuleCall associated with this rule
*/
private static void findBestOrderings(
RelMetadataQuery mq,
RelBuilder relBuilder,
private void findBestOrderings(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
RelOptRuleCall call) {
LoptSemiJoinOptimizer semiJoinOpt) {
final List<RelNode> plans = new ArrayList<>();

final List<String> fieldNames =
Expand All @@ -461,8 +459,7 @@ private static void findBestOrderings(
}
LoptJoinTree joinTree =
createOrdering(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
i);
Expand Down Expand Up @@ -679,9 +676,8 @@ private static void setFactorJoinKeys(
* @return constructed join tree or null if it is not possible for
* firstFactor to appear as the first factor in the join
*/
private static @Nullable LoptJoinTree createOrdering(
RelMetadataQuery mq,
RelBuilder relBuilder,
private @Nullable LoptJoinTree createOrdering(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
int firstFactor) {
Expand Down Expand Up @@ -712,7 +708,7 @@ private static void setFactorJoinKeys(
} else {
nextFactor =
getBestNextFactor(
mq,
call.getMetadataQuery(),
multiJoin,
factorsToAdd,
factorsAdded,
Expand All @@ -733,8 +729,7 @@ private static void setFactorJoinKeys(
factorsNeeded.and(factorsAdded);
joinTree =
addFactorToTree(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
joinTree,
Expand Down Expand Up @@ -878,16 +873,17 @@ private static boolean isJoinTree(RelNode rel) {
* @return optimal join tree with the new factor added if it is possible to
* add the factor; otherwise, null is returned
*/
private static @Nullable LoptJoinTree addFactorToTree(
RelMetadataQuery mq,
RelBuilder relBuilder,
private @Nullable LoptJoinTree addFactorToTree(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
@Nullable LoptJoinTree joinTree,
int factorToAdd,
BitSet factorsNeeded,
List<RexNode> filtersToAdd,
boolean selfJoin) {
final RelMetadataQuery mq = call.getMetadataQuery();
final RelBuilder relBuilder = call.builder();

// if the factor corresponds to the null generating factor in an outer
// join that can be removed, then create a replacement join
Expand Down Expand Up @@ -943,8 +939,7 @@ private static boolean isJoinTree(RelNode rel) {
selfJoin);
LoptJoinTree pushDownTree =
pushDownFactor(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
joinTree,
Expand All @@ -959,10 +954,10 @@ private static boolean isJoinTree(RelNode rel) {
RelOptCost costPushDown = null;
RelOptCost costTop = null;
if (pushDownTree != null) {
costPushDown = mq.getCumulativeCost(pushDownTree.getJoinTree());
costPushDown = config.costFunction().getCost(call, pushDownTree.getJoinTree());
}
if (topTree != null) {
costTop = mq.getCumulativeCost(topTree.getJoinTree());
costTop = config.costFunction().getCost(call, topTree.getJoinTree());
}

if (pushDownTree == null) {
Expand Down Expand Up @@ -1035,9 +1030,8 @@ private static int rowWidthCost(RelNode tree) {
* join tree if it is possible to do the pushdown; otherwise, null is
* returned
*/
private static @Nullable LoptJoinTree pushDownFactor(
RelMetadataQuery mq,
RelBuilder relBuilder,
private @Nullable LoptJoinTree pushDownFactor(
RelOptRuleCall call,
LoptMultiJoin multiJoin,
LoptSemiJoinOptimizer semiJoinOpt,
LoptJoinTree joinTree,
Expand Down Expand Up @@ -1110,8 +1104,7 @@ private static int rowWidthCost(RelNode tree) {
LoptJoinTree subTree = (childNo == 0) ? left : right;
subTree =
addFactorToTree(
mq,
relBuilder,
call,
multiJoin,
semiJoinOpt,
subTree,
Expand Down Expand Up @@ -1165,8 +1158,8 @@ private static int rowWidthCost(RelNode tree) {

// create the new join tree with the factor pushed down
return createJoinSubtree(
mq,
relBuilder,
call.getMetadataQuery(),
call.builder(),
multiJoin,
left,
right,
Expand Down Expand Up @@ -2089,12 +2082,26 @@ private static boolean areSelfJoinKeysUnique(RelMetadataQuery mq,
joinInfo.leftSet());
}

/** Function to compute cost. */
@FunctionalInterface
public interface CostFunction {
@Nullable RelOptCost getCost(RelOptRuleCall call, RelNode relNode);
}

/** Rule configuration. */
@Value.Immutable
public interface Config extends RelRule.Config {
Config DEFAULT = ImmutableLoptOptimizeJoinRule.Config.of()
.withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs());

/** Function to calculate intermediate cost computations. */
@Value.Default default CostFunction costFunction() {
return (call, rel) -> call.getMetadataQuery().getCumulativeCost(rel);
}

/** Sets {@link #costFunction()}. */
Config withCostFunction(CostFunction function);

@Override default LoptOptimizeJoinRule toRule() {
return new LoptOptimizeJoinRule(this);
}
Expand Down
66 changes: 66 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.config.CalciteConnectionConfig;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptCostImpl;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptRule;
Expand Down Expand Up @@ -53,6 +54,7 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.hint.HintPredicates;
import org.apache.calcite.rel.hint.HintStrategyTable;
Expand All @@ -62,6 +64,7 @@
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.AggregateExpandWithinDistinctRule;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.rel.rules.AggregateProjectConstantToDummyJoinRule;
Expand All @@ -78,6 +81,7 @@
import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
import org.apache.calcite.rel.rules.JoinAssociateRule;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.rules.LoptOptimizeJoinRule;
import org.apache.calcite.rel.rules.MeasureRules;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.ProjectCorrelateTransposeRule;
Expand Down Expand Up @@ -9733,6 +9737,68 @@ private void checkJoinAssociateRuleWithTopAlwaysTrueCondition(boolean allowAlway
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
* LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
@Test void testLoptOptimizeJoinRuleWithDefaultCost() {
// Use the default rule
checkLoptOptimizeJoinRule(CoreRules.MULTI_JOIN_OPTIMIZE);
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6788">[CALCITE-6788]
* LoptOptimizeJoinRule should be able to delegate costs to the planner</a>. */
@Test void testLoptOptimizeJoinRuleWithSpecialCost() {
// Use an ad-hoc version of the rule that uses planner#getCost instead of mq#getCumulativeCost
checkLoptOptimizeJoinRule(LoptOptimizeJoinRule.Config.DEFAULT
.withCostFunction((c, r) -> c.getPlanner().getCost(r, c.getMetadataQuery()))
.toRule());
}

private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) {
final HepProgram preProgram = new HepProgramBuilder()
.addMatchOrder(HepMatchOrder.BOTTOM_UP)
.addRuleInstance(CoreRules.JOIN_TO_MULTI_JOIN)
.build();

final HepProgram program = HepProgram.builder()
.addMatchOrder(HepMatchOrder.BOTTOM_UP)
.addRuleInstance(rule)
.build();

// Special planner that artificially favors joins on the same table
final HepPlanner planner = new HepPlanner(program) {
@Override public RelOptCost getCost(RelNode rel, RelMetadataQuery mq) {
if (rel instanceof Join
&& rel.getInput(0).stripped() instanceof TableScan
&& rel.getInput(1).stripped() instanceof TableScan) {
TableScan left = (TableScan) rel.getInput(0).stripped();
TableScan right = (TableScan) rel.getInput(1).stripped();
if (left.getTable().equals(right.getTable())) {
// Tiny cost for self-joins
return getCostFactory().makeTinyCost();
}
}

// General case: just define a kind of cumulative cost based on the rowCount (to avoid
// the infinite costs from the Logical operators)
RelOptCost cost = new RelOptCostImpl(mq.getRowCount(rel));
for (RelNode input : rel.getInputs()) {
cost = cost.plus(getCost(input, mq));
}
return cost;
}
};

sql("select e1.empno from emp e1"
+ " inner join dept d1 on d1.deptno = e1.deptno"
+ " inner join emp e2 on e1.ename = e2.ename"
+ " inner join dept d2 on d2.deptno = e1.deptno")
.withPre(preProgram)
.withPlanner(planner)
.check();
}

/**
* Test case of
* <a href="https://issues.apache.org/jira/browse/CALCITE-6850">[CALCITE-6850]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7419,6 +7419,62 @@ LogicalProject(EMPNO=[$0])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testLoptOptimizeJoinRuleWithDefaultCost">
<Resource name="sql">
<![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalProject(EMPNO=[$9], ENAME=[$10], JOB=[$11], MGR=[$12], HIREDATE=[$13], SAL=[$14], COMM=[$15], DEPTNO=[$16], SLACKER=[$17], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$0], ENAME0=[$1], JOB0=[$2], MGR0=[$3], HIREDATE0=[$4], SAL0=[$5], COMM0=[$6], DEPTNO1=[$7], SLACKER0=[$8], DEPTNO2=[$20], NAME0=[$21])
LogicalJoin(condition=[=($10, $1)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalJoin(condition=[=($11, $7)], joinType=[inner])
LogicalJoin(condition=[=($9, $7)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
<TestCase name="testLoptOptimizeJoinRuleWithSpecialCost">
<Resource name="sql">
<![CDATA[select e1.empno from emp e1 inner join dept d1 on d1.deptno = e1.deptno inner join emp e2 on e1.ename = e2.ename inner join dept d2 on d2.deptno = e1.deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalProject(EMPNO=[$0])
MultiJoin(joinFilter=[AND(=($20, $7), =($1, $12), =($9, $7))], isFullOuterJoin=[false], joinTypes=[[INNER, INNER, INNER, INNER]], outerJoinConditions=[[NULL, NULL, NULL, NULL]], projFields=[[ALL, ALL, ALL, ALL]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalProject(EMPNO=[$0])
LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], DEPTNO0=[$18], NAME=[$19], EMPNO0=[$9], ENAME0=[$10], JOB0=[$11], MGR0=[$12], HIREDATE0=[$13], SAL0=[$14], COMM0=[$15], DEPTNO1=[$16], SLACKER0=[$17], DEPTNO2=[$20], NAME0=[$21])
LogicalJoin(condition=[=($20, $7)], joinType=[inner])
LogicalJoin(condition=[=($18, $7)], joinType=[inner])
LogicalJoin(condition=[=($1, $10)], joinType=[inner])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
]]>
</Resource>
</TestCase>
Expand Down