diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryDecorrelator.java b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryDecorrelator.java new file mode 100644 index 00000000000..6a60bedbe6a --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryDecorrelator.java @@ -0,0 +1,1447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.Contexts; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.hep.HepRelVertex; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttleImpl; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.CorrelationId; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.core.SetOp; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.core.Values; +import org.apache.calcite.rel.hint.Hintable; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.logical.LogicalCorrelate; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalIntersect; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalMinus; +import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.logical.LogicalSort; +import org.apache.calcite.rel.logical.LogicalUnion; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexOver; +import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSubQuery; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitor; +import org.apache.calcite.rex.RexVisitorImpl; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.Bug; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Litmus; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.ReflectUtil; +import org.apache.calcite.util.ReflectiveVisitor; +import org.apache.calcite.util.Util; +import org.apache.calcite.util.mapping.Mappings; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.TreeSet; + +/** + * SubQueryDecorrelator finds all correlated expressions in a SubQuery, and gets an equivalent + * non-correlated relational expression tree and correlation conditions. + * + *

The Basic idea of SubQueryDecorrelator is from {@link + * org.apache.calcite.sql2rel.RelDecorrelator}, however there are differences between them: 1. This + * class works with {@link RexSubQuery}, while RelDecorrelator works with {@link LogicalCorrelate}. + * 2. This class will get an equivalent non-correlated expressions tree and correlation conditions, + * while RelDecorrelator will replace all correlated expressions with non-correlated expressions + * that are produced from joining the RelNode. 3. This class supports both equi and non-equi + * correlation conditions, while RelDecorrelator only supports equi correlation conditions. + */ +public class SubQueryDecorrelator extends RelShuttleImpl { + + private final SubQueryRelDecorrelator decorrelator; + private final RelBuilder relBuilder; + + // map a SubQuery to an equivalent RelNode and correlation-condition pair + private final Map> subQueryMap = new HashMap<>(); + + private SubQueryDecorrelator(SubQueryRelDecorrelator decorrelator, RelBuilder relBuilder) { + this.decorrelator = decorrelator; + this.relBuilder = relBuilder; + } + + /** + * Decorrelates a subquery. + * + *

This is the main entry point to {@code SubQueryDecorrelator}. + * + * @param rootRel The node which has SubQuery. + * @return Decorrelate result. + */ + public static Result decorrelateQuery(RelNode rootRel) { + int maxCnfNodeCount = -1; + + final CorelMapBuilder builder = new CorelMapBuilder(maxCnfNodeCount); + final CorelMap corelMap = builder.build(rootRel); + if (builder.hasNestedCorScope || builder.hasUnsupportedCorCondition) { + return null; + } + + if (!corelMap.hasCorrelation()) { + return Result.EMPTY; + } + + RelOptCluster cluster = rootRel.getCluster(); + RelBuilder relBuilder = RelBuilder.proto(Contexts.of()).create(cluster, null); + RexBuilder rexBuilder = cluster.getRexBuilder(); + + final SubQueryDecorrelator decorrelator = + new SubQueryDecorrelator( + new SubQueryRelDecorrelator(corelMap, relBuilder, rexBuilder, maxCnfNodeCount), + relBuilder); + rootRel.accept(decorrelator); + + return new Result(decorrelator.subQueryMap); + } + + @Override protected RelNode visitChild(RelNode parent, int i, RelNode input) { + return super.visitChild(parent, i, stripHep(input)); + } + + @Override public RelNode visit(final LogicalFilter filter) { + try { + stack.push(filter); + filter.getCondition().accept(handleSubQuery(filter)); + } finally { + stack.pop(); + } + return super.visit(filter); + } + + private RexVisitorImpl handleSubQuery(final RelNode rel) { + return new RexVisitorImpl(true) { + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + RelNode newRel = subQuery.rel; + if (subQuery.getKind() == SqlKind.IN) { + newRel = addProjectionForIn(subQuery.rel); + } + final Frame frame = decorrelator.getInvoke(newRel); + if (frame != null && frame.c != null) { + + Frame target = frame; + if (subQuery.getKind() == SqlKind.EXISTS) { + target = addProjectionForExists(frame); + } + + final DecorrelateRexShuttle shuttle = + new DecorrelateRexShuttle(rel.getRowType(), + target.r.getRowType(), + rel.getVariablesSet()); + + final RexNode newCondition = target.c.accept(shuttle); + Pair newNodeAndCondition = new Pair<>(target.r, newCondition); + subQueryMap.put(subQuery, newNodeAndCondition); + } + return null; + } + }; + } + + /** + * Adds Projection to adjust the field index for join condition. + * + *

e.g. SQL: SELECT * FROM l WHERE b IN (SELECT COUNT(*) FROM r WHERE l.c = r.f the rel in + * SubQuery is `LogicalAggregate(group=[{}], EXPR$1=[COUNT()])`. After decorrelated, it was + * changed to `LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])`, and the output index of + * `COUNT()` was changed from 0 to 1. So, add a project (`LogicalProject(EXPR$0=[$1], f=[$0])`) + * to adjust output fields order. + */ + private RelNode addProjectionForIn(RelNode relNode) { + if (relNode instanceof LogicalProject) { + return relNode; + } + + RelDataType rowType = relNode.getRowType(); + final List projects = new ArrayList<>(); + for (int i = 0; i < rowType.getFieldCount(); ++i) { + projects.add(RexInputRef.of(i, rowType)); + } + + relBuilder.clear(); + relBuilder.push(relNode); + relBuilder.project(projects, rowType.getFieldNames(), true); + return relBuilder.build(); + } + + /** Adds Projection to choose the fields used by join condition. */ + private Frame addProjectionForExists(Frame frame) { + final List corIndices = new ArrayList<>(frame.getCorInputRefIndices()); + final RelNode rel = frame.r; + final RelDataType rowType = rel.getRowType(); + if (corIndices.size() == rowType.getFieldCount()) { + // no need projection + return frame; + } + + final List projects = new ArrayList<>(); + final Map mapInputToOutput = new HashMap<>(); + + Collections.sort(corIndices); + int newPos = 0; + for (int index : corIndices) { + projects.add(RexInputRef.of(index, rowType)); + mapInputToOutput.put(index, newPos++); + } + + relBuilder.clear(); + relBuilder.push(frame.r); + relBuilder.project(projects); + final RelNode newProject = relBuilder.build(); + final RexNode newCondition = + adjustInputRefs(frame.c, mapInputToOutput, newProject.getRowType()); + + // There is no old RelNode corresponding to newProject, so oldToNewOutputs is empty. + return new Frame(rel, newProject, newCondition, new HashMap<>()); + } + + private static RelNode stripHep(RelNode rel) { + if (rel instanceof HepRelVertex) { + HepRelVertex hepRelVertex = (HepRelVertex) rel; + rel = hepRelVertex.getCurrentRel(); + } + return rel; + } + + private static void analyzeCorConditions( + final Set variableSet, + final RexNode condition, + final RexBuilder rexBuilder, + final int maxCnfNodeCount, + final List corConditions, + final List nonCorConditions, + final List unsupportedCorConditions) { + // converts the expanded expression to conjunctive normal form, + // like "(a AND b) OR c" will be converted to "(a OR c) AND (b OR c)" + final RexNode cnf = RexUtil.toCnf(rexBuilder, maxCnfNodeCount, condition); + // converts the cnf condition to a list of AND conditions + final List conjunctions = RelOptUtil.conjunctions(cnf); + // `true` for RexNode is supported correlation condition, + // `false` for RexNode is unsupported correlation condition, + // `null` for RexNode is not a correlation condition. + final RexVisitorImpl visitor = new RexVisitorImpl(true) { + + @Override public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { + final RexNode ref = fieldAccess.getReferenceExpr(); + if (ref instanceof RexCorrelVariable) { + return visitCorrelVariable((RexCorrelVariable) ref); + } else { + return super.visitFieldAccess(fieldAccess); + } + } + + @Override public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) { + return variableSet.contains(correlVariable.id); + } + + @Override public Boolean visitSubQuery(RexSubQuery subQuery) { + final List result = new ArrayList<>(); + for (RexNode operand : subQuery.operands) { + result.add(operand.accept(this)); + } + // we do not support nested correlation variables in SubQuery, such as: + // select * from t1 where exists(select * from t2 where t1.a = t2.c and t1.b + // in (select t3.d from t3) + if (result.contains(true) || result.contains(false)) { + return false; + } else { + return null; + } + } + + @Override public Boolean visitCall(RexCall call) { + final List result = new ArrayList<>(); + for (RexNode operand : call.operands) { + result.add(operand.accept(this)); + } + if (result.contains(false)) { + return false; + } else if (result.contains(true)) { + // TODO supports correlation variable with OR + // return call.op.getKind() != SqlKind.OR || !result.contains(null); + return call.op.getKind() != SqlKind.OR; + } else { + return null; + } + } + }; + + for (RexNode c : conjunctions) { + Boolean r = c.accept(visitor); + if (r == null) { + nonCorConditions.add(c); + } else if (r) { + corConditions.add(c); + } else { + unsupportedCorConditions.add(c); + } + } + } + + /** + * Adjust the condition's field indices according to mapOldToNewIndex. + * + * @param c The condition to be adjusted. + * @param mapOldToNewIndex A map containing the mapping the old field indices to new field + * indices. + * @param rowType The row type of the new output. + * @return Return new condition with new field indices. + */ + private static RexNode adjustInputRefs(final RexNode c, + final Map mapOldToNewIndex, final RelDataType rowType) { + return c.accept(new RexShuttle() { + + @Override public RexNode visitInputRef(RexInputRef inputRef) { + assert mapOldToNewIndex.containsKey(inputRef.getIndex()); + int newIndex = mapOldToNewIndex.get(inputRef.getIndex()); + final RexInputRef ref = RexInputRef.of(newIndex, rowType); + if (ref.getIndex() == inputRef.getIndex() && ref.getType() == inputRef.getType()) { + return inputRef; // re-use old object, to prevent needless expr cloning + } else { + return ref; + } + } + }); + } + + /** + * Shuttle that shifts inputs of correlations. + */ + private static class DecorrelateRexShuttle extends RexShuttle { + + private final RelDataType leftRowType; + private final RelDataType rightRowType; + private final Set variableSet; + + private DecorrelateRexShuttle(RelDataType leftRowType, RelDataType rightRowType, + Set variableSet) { + this.leftRowType = leftRowType; + this.rightRowType = rightRowType; + this.variableSet = variableSet; + } + + @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + final RexNode ref = fieldAccess.getReferenceExpr(); + if (ref instanceof RexCorrelVariable) { + final RexCorrelVariable var = (RexCorrelVariable) ref; + assert variableSet.contains(var.id); + final RelDataTypeField field = fieldAccess.getField(); + return new RexInputRef(field.getIndex(), field.getType()); + } else { + return super.visitFieldAccess(fieldAccess); + } + } + + @Override public RexNode visitInputRef(RexInputRef inputRef) { + assert inputRef.getIndex() < rightRowType.getFieldCount(); + int newIndex = inputRef.getIndex() + leftRowType.getFieldCount(); + return new RexInputRef(newIndex, inputRef.getType()); + } + } + + /** + * Pull out all correlation conditions from a given subquery to top level, and rebuild the + * subquery rel tree without correlation conditions. + * + *

`public` is for reflection. We use ReflectiveVisitor instead of RelShuttle because + * RelShuttle returns RelNode. + */ + public static class SubQueryRelDecorrelator implements ReflectiveVisitor { + + // map built during translation + private final CorelMap cm; + private final RelBuilder relBuilder; + private final RexBuilder rexBuilder; + private final ReflectUtil.MethodDispatcher dispatcher = + ReflectUtil.createMethodDispatcher(Frame.class, + this, + "decorrelateRel", + RelNode.class); + private final int maxCnfNodeCount; + + SubQueryRelDecorrelator(CorelMap cm, RelBuilder relBuilder, RexBuilder rexBuilder, + int maxCnfNodeCount) { + this.cm = cm; + this.relBuilder = relBuilder; + this.rexBuilder = rexBuilder; + this.maxCnfNodeCount = maxCnfNodeCount; + } + + Frame getInvoke(RelNode r) { + return dispatcher.invoke(r); + } + + /** + * Rewrite LogicalProject. + * + *

Rewrite logic: Pass along any correlated variables coming from the input. + * + * @param rel the project rel to rewrite + */ + public Frame decorrelateRel(LogicalProject rel) { + final RelNode oldInput = rel.getInput(); + Frame frame = getInvoke(oldInput); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + + final List oldProjects = rel.getProjects(); + final List relOutput = rel.getRowType().getFieldList(); + final RelNode newInput = frame.r; + + // Project projects the original expressions, + // plus any correlated variables the input wants to pass along. + final List> projects = new ArrayList<>(); + + // If this Project has correlated reference, produce the correlated variables in the new + // output. + // TODO Currently, correlation in projection is not supported. + assert !cm.mapRefRelToCorRef.containsKey(rel); + + final Map mapInputToOutput = new HashMap<>(); + final Map mapOldToNewOutputs = new HashMap<>(); + // Project projects the original expressions + int newPos; + for (newPos = 0; newPos < oldProjects.size(); newPos++) { + RexNode project = + adjustInputRefs(oldProjects.get(newPos), frame.oldToNewOutputs, newInput.getRowType()); + projects.add(newPos, Pair.of(project, relOutput.get(newPos).getName())); + mapOldToNewOutputs.put(newPos, newPos); + if (project instanceof RexInputRef) { + mapInputToOutput.put(((RexInputRef) project).getIndex(), newPos); + } + } + + if (frame.c != null) { + // Project any correlated variables the input wants to pass along. + final ImmutableBitSet corInputIndices = RelOptUtil.InputFinder.bits(frame.c); + final RelDataType inputRowType = newInput.getRowType(); + for (int inputIndex : corInputIndices.toList()) { + if (!mapInputToOutput.containsKey(inputIndex)) { + projects.add( + newPos, + Pair.of(RexInputRef.of(inputIndex, inputRowType), + inputRowType.getFieldNames().get(inputIndex))); + mapInputToOutput.put(inputIndex, newPos); + newPos++; + } + } + } + + RelNode newProject = RelFactories.LOGICAL_BUILDER.create(newInput.getCluster(), null) + .push(newInput).projectNamed(Pair.left(projects), Pair.right(projects), false) + .build(); + newProject = ((LogicalProject) newProject).withHints(rel.getHints()); + + final RexNode newCorCondition; + if (frame.c != null) { + newCorCondition = adjustInputRefs(frame.c, mapInputToOutput, newProject.getRowType()); + } else { + newCorCondition = null; + } + + return new Frame(rel, newProject, newCorCondition, mapOldToNewOutputs); + } + + /** + * Rewrite LogicalFilter. + * + *

Rewrite logic: 1. If a Filter references a correlated field in its filter condition, + * rewrite the Filter references only non-correlated fields, and the condition references + * correlated fields will be push to it's output. 2. If Filter does not reference correlated + * variables, simply rewrite the filter condition using new input. + * + * @param rel the filter rel to rewrite + */ + public Frame decorrelateRel(LogicalFilter rel) { + final RelNode oldInput = rel.getInput(); + Frame frame = getInvoke(oldInput); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + + // Conditions reference only correlated fields + final List corConditions = new ArrayList<>(); + // Conditions do not reference any correlated fields + final List nonCorConditions = new ArrayList<>(); + // Conditions reference correlated fields, but not supported now + final List unsupportedCorConditions = new ArrayList<>(); + + analyzeCorConditions( + cm.mapSubQueryNodeToCorSet.get(rel), + rel.getCondition(), + rexBuilder, + maxCnfNodeCount, + corConditions, + nonCorConditions, + unsupportedCorConditions); + assert unsupportedCorConditions.isEmpty(); + + final RexNode remainingCondition = + RexUtil.composeConjunction(rexBuilder, nonCorConditions, false); + + // Using LogicalFilter.create instead of RelBuilder.filter to create Filter + // because RelBuilder.filter method does not have VariablesSet arg. + final RelNode newFilter = + LogicalFilter.create(frame.r, + remainingCondition, + com.google.common.collect.ImmutableSet.copyOf(rel.getVariablesSet())) + .withHints(rel.getHints()); + + // Adds input's correlation condition + if (frame.c != null) { + corConditions.add(frame.c); + } + + final RexNode corCondition = RexUtil.composeConjunction(rexBuilder, corConditions, true); + // Filter does not change the input ordering. + // All corVars produced by filter will have the same output positions in the input rel. + return new Frame(rel, newFilter, corCondition, frame.oldToNewOutputs); + } + + /** + * Rewrites a {@link LogicalAggregate}. + * + *

Rewrite logic: 1. Permute the group by keys to the front. 2. If the input of an + * aggregate produces correlated variables, add them to the group list. 3. Change aggCalls + * to reference the new project. + * + * @param rel Aggregate to rewrite + */ + public Frame decorrelateRel(LogicalAggregate rel) { + // Aggregate itself should not reference corVars. + assert !cm.mapRefRelToCorRef.containsKey(rel); + + final RelNode oldInput = rel.getInput(); + final Frame frame = getInvoke(oldInput); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + + final RelNode newInput = frame.r; + // map from newInput + final Map mapNewInputToProjOutputs = new HashMap<>(); + final int oldGroupKeyCount = rel.getGroupSet().cardinality(); + + // Project projects the original expressions, + // plus any correlated variables the input wants to pass along. + final List> projects = new ArrayList<>(); + final List newInputOutput = newInput.getRowType().getFieldList(); + + // oldInput has the original group by keys in the front. + final NavigableMap omittedConstants = new TreeMap<>(); + int newPos = 0; + for (int i = 0; i < oldGroupKeyCount; i++) { + final RexLiteral constant = projectedLiteral(newInput, i); + if (constant != null) { + // Exclude constants. Aggregate({true}) occurs because Aggregate({}) + // would generate 1 row even when applied to an empty table. + omittedConstants.put(i, constant); + continue; + } + + int newInputPos = frame.oldToNewOutputs.get(i); + projects.add(newPos, RexInputRef.of2(newInputPos, newInputOutput)); + mapNewInputToProjOutputs.put(newInputPos, newPos); + newPos++; + } + + if (frame.c != null) { + // If input produces correlated variables, move them to the front, + // right after any existing GROUP BY fields. + + // Now add the corVars from the input, starting from position oldGroupKeyCount. + for (Integer index : frame.getCorInputRefIndices()) { + if (!mapNewInputToProjOutputs.containsKey(index)) { + projects.add(newPos, RexInputRef.of2(index, newInputOutput)); + mapNewInputToProjOutputs.put(index, newPos); + newPos++; + } + } + } + + // add the remaining fields + final int newGroupKeyCount = newPos; + for (int i = 0; i < newInputOutput.size(); i++) { + if (!mapNewInputToProjOutputs.containsKey(i)) { + projects.add(newPos, RexInputRef.of2(i, newInputOutput)); + mapNewInputToProjOutputs.put(i, newPos); + newPos++; + } + } + + assert newPos == newInputOutput.size(); + + // This Project will be what the old input maps to, + // replacing any previous mapping from old input). + final RelNode newProject = RelFactories.LOGICAL_BUILDER.create(newInput.getCluster(), null) + .push(newInput).projectNamed(Pair.left(projects), Pair.right(projects), false) + .build(); + final RexNode newCondition; + if (frame.c != null) { + newCondition = adjustInputRefs(frame.c, mapNewInputToProjOutputs, newProject.getRowType()); + } else { + newCondition = null; + } + + // update mappings: + // oldInput ----> newInput + // + // newProject + // | + // oldInput ----> newInput + // + // is transformed to + // + // oldInput ----> newProject + // | + // newInput + + final Map combinedMap = new HashMap<>(); + final Map oldToNewOutputs = new HashMap<>(); + final List originalGrouping = rel.getGroupSet().toList(); + for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) { + final Integer newIndex = + mapNewInputToProjOutputs.get(frame.oldToNewOutputs.get(oldInputPos)); + combinedMap.put(oldInputPos, newIndex); + // mapping grouping fields + if (originalGrouping.contains(oldInputPos)) { + oldToNewOutputs.put(oldInputPos, newIndex); + } + } + + // now it's time to rewrite the Aggregate + final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount); + final List newAggCalls = new ArrayList<>(); + final List oldAggCalls = rel.getAggCallList(); + + for (AggregateCall oldAggCall : oldAggCalls) { + final List oldAggArgs = oldAggCall.getArgList(); + final List aggArgs = new ArrayList<>(); + + // Adjust the Aggregate argument positions. + // Note Aggregate does not change input ordering, so the input + // output position mapping can be used to derive the new positions + // for the argument. + for (int oldPos : oldAggArgs) { + aggArgs.add(combinedMap.get(oldPos)); + } + final int filterArg = oldAggCall.filterArg < 0 + ? oldAggCall.filterArg + : combinedMap.get(oldAggCall.filterArg); + + newAggCalls.add( + oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount)); + } + + relBuilder.push( + LogicalAggregate.create(newProject, rel.getHints(), newGroupSet, null, newAggCalls)); + + if (!omittedConstants.isEmpty()) { + final List postProjects = new ArrayList<>(relBuilder.fields()); + for (Map.Entry entry : omittedConstants.entrySet()) { + postProjects.add(mapNewInputToProjOutputs.get(entry.getKey()), entry.getValue()); + } + relBuilder.project(postProjects); + } + + // mapping aggCall output fields + for (int i = 0; i < oldAggCalls.size(); ++i) { + oldToNewOutputs.put(oldGroupKeyCount + i, newGroupKeyCount + omittedConstants.size() + i); + } + + // Aggregate does not change input ordering so corVars will be + // located at the same position as the input newProject. + return new Frame(rel, relBuilder.build(), newCondition, oldToNewOutputs); + } + + /** + * Rewrite LogicalJoin. + * + *

Rewrite logic: 1. rewrite join condition. 2. map output positions and produce corVars + * if any. + * + * @param rel Join + */ + public Frame decorrelateRel(LogicalJoin rel) { + final RelNode oldLeft = rel.getInput(0); + final RelNode oldRight = rel.getInput(1); + + final Frame leftFrame = getInvoke(oldLeft); + final Frame rightFrame = getInvoke(oldRight); + + if (leftFrame == null || rightFrame == null) { + // If any input has not been rewritten, do not rewrite this rel. + return null; + } + + switch (rel.getJoinType()) { + case LEFT: + assert rightFrame.c == null; + break; + case RIGHT: + assert leftFrame.c == null; + break; + case FULL: + assert leftFrame.c == null && rightFrame.c == null; + break; + default: + break; + } + + final int oldLeftFieldCount = oldLeft.getRowType().getFieldCount(); + final int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount(); + final int oldRightFieldCount = oldRight.getRowType().getFieldCount(); + assert rel.getRowType().getFieldCount() == oldLeftFieldCount + oldRightFieldCount; + + final RexNode newJoinCondition = + adjustJoinCondition(rel.getCondition(), + oldLeftFieldCount, + newLeftFieldCount, + leftFrame.oldToNewOutputs, + rightFrame.oldToNewOutputs); + + final RelNode newJoin = + LogicalJoin.create(leftFrame.r, + rightFrame.r, + rel.getHints(), + newJoinCondition, + rel.getVariablesSet(), + rel.getJoinType()); + + // Create the mapping between the output of the old correlation rel and the new join rel + final Map mapOldToNewOutputs = new HashMap<>(); + // Left input positions are not changed. + mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs); + + // Right input positions are shifted by newLeftFieldCount. + for (int i = 0; i < oldRightFieldCount; i++) { + mapOldToNewOutputs.put(i + oldLeftFieldCount, + rightFrame.oldToNewOutputs.get(i) + newLeftFieldCount); + } + + final List corConditions = new ArrayList<>(); + if (leftFrame.c != null) { + corConditions.add(leftFrame.c); + } + if (rightFrame.c != null) { + // Right input positions are shifted by newLeftFieldCount. + final Map rightMapOldToNewOutputs = new HashMap<>(); + for (int index : rightFrame.getCorInputRefIndices()) { + rightMapOldToNewOutputs.put(index, index + newLeftFieldCount); + } + final RexNode newRightCondition = + adjustInputRefs(rightFrame.c, rightMapOldToNewOutputs, newJoin.getRowType()); + corConditions.add(newRightCondition); + } + + final RexNode newCondition = RexUtil.composeConjunction(rexBuilder, corConditions, true); + return new Frame(rel, newJoin, newCondition, mapOldToNewOutputs); + } + + private RexNode adjustJoinCondition( + final RexNode joinCondition, + final int oldLeftFieldCount, + final int newLeftFieldCount, + final Map leftOldToNewOutputs, + final Map rightOldToNewOutputs) { + return joinCondition.accept(new RexShuttle() { + + @Override public RexNode visitInputRef(RexInputRef inputRef) { + int oldIndex = inputRef.getIndex(); + final int newIndex; + if (oldIndex < oldLeftFieldCount) { + // field from left + assert leftOldToNewOutputs.containsKey(oldIndex); + newIndex = leftOldToNewOutputs.get(oldIndex); + } else { + // field from right + oldIndex = oldIndex - oldLeftFieldCount; + assert rightOldToNewOutputs.containsKey(oldIndex); + newIndex = rightOldToNewOutputs.get(oldIndex) + newLeftFieldCount; + } + return new RexInputRef(newIndex, inputRef.getType()); + } + }); + } + + /** + * Rewrite Sort. + * + *

Rewrite logic: change the collations field to reference the new input. + * + * @param rel Sort to be rewritten + */ + public Frame decorrelateRel(Sort rel) { + // Sort itself should not reference corVars. + assert !cm.mapRefRelToCorRef.containsKey(rel); + + // Sort only references field positions in collations field. + // The collations field in the newRel now need to refer to the + // new output positions in its input. + // Its output does not change the input ordering, so there's no + // need to call propagateExpr. + final RelNode oldInput = rel.getInput(); + final Frame frame = getInvoke(oldInput); + if (frame == null) { + // If input has not been rewritten, do not rewrite this rel. + return null; + } + final RelNode newInput = frame.r; + + Mappings.TargetMapping mapping = + Mappings.target(frame.oldToNewOutputs, + oldInput.getRowType().getFieldCount(), + newInput.getRowType().getFieldCount()); + + RelCollation oldCollation = rel.getCollation(); + RelCollation newCollation = RexUtil.apply(mapping, oldCollation); + + final RelNode newSort = LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch) + .withHints(rel.getHints()); + + // Sort does not change input ordering + return new Frame(rel, newSort, frame.c, frame.oldToNewOutputs); + } + + /** + * Rewrites a {@link Values}. + * + * @param rel Values to be rewritten + */ + public Frame decorrelateRel(Values rel) { + // There are no inputs, so rel does not need to be changed. + return null; + } + + public Frame decorrelateRel(LogicalCorrelate rel) { + // does not allow correlation condition in its inputs now, so choose default behavior + return decorrelateRel((RelNode) rel); + } + + /** Fallback if none of the other {@code decorrelateRel} methods match. */ + public Frame decorrelateRel(RelNode rel) { + RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs()); + if (rel.getInputs().size() > 0) { + List oldInputs = rel.getInputs(); + List newInputs = new ArrayList<>(); + for (int i = 0; i < oldInputs.size(); ++i) { + final Frame frame = getInvoke(oldInputs.get(i)); + if (frame == null || frame.c != null) { + // if input is not rewritten, or if it produces correlated variables, + // terminate rewrite + return null; + } + newInputs.add(frame.r); + newRel.replaceInput(i, frame.r); + } + + if (!Util.equalShallow(oldInputs, newInputs)) { + newRel = rel.copy(rel.getTraitSet(), newInputs); + } + if (rel instanceof Hintable) { + newRel = ((Hintable) newRel).withHints(((Hintable) rel).getHints()); + } + } + // the output position should not change since there are no corVars coming from below. + return new Frame(rel, newRel, null, identityMap(rel.getRowType().getFieldCount())); + } + + /* Returns an immutable map with the identity [0: 0, .., count-1: count-1]. */ + private static Map identityMap(int count) { + ImmutableMap.Builder builder = + ImmutableMap.builder(); + for (int i = 0; i < count; i++) { + builder.put(i, i); + } + return builder.build(); + } + + /** Returns a literal output field, or null if it is not literal. */ + private static RexLiteral projectedLiteral(RelNode rel, int i) { + if (rel instanceof Project) { + final Project project = (Project) rel; + final RexNode node = project.getProjects().get(i); + if (node instanceof RexLiteral) { + return (RexLiteral) node; + } + } + return null; + } + } + + /** Builds a {@link CorelMap}. */ + private static class CorelMapBuilder extends RelShuttleImpl { + + private final int maxCnfNodeCount; + // nested correlation variables in SubQuery, such as: + // SELECT * FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t1.a = t2.c AND + // t2.d IN (SELECT t3.d FROM t3 WHERE t1.b = t3.e) + boolean hasNestedCorScope = false; + // has unsupported correlation condition, such as: + // SELECT * FROM l WHERE a IN (SELECT c FROM r WHERE l.b IN (SELECT e FROM t)) + // SELECT a FROM l WHERE b IN (SELECT r1.e FROM r1 WHERE l.a = r1.d UNION SELECT r2.i FROM + // r2) + // SELECT * FROM l WHERE EXISTS (SELECT * FROM r LEFT JOIN (SELECT * FROM t WHERE t.j = l.b) + // t1 ON r.f = t1.k) + // SELECT * FROM l WHERE b IN (SELECT MIN(e) FROM r WHERE l.c > r.f) + // SELECT * FROM l WHERE b IN (SELECT MIN(e) OVER() FROM r WHERE l.c > r.f) + boolean hasUnsupportedCorCondition = false; + // true if SubQuery rel tree has Aggregate node, else false. + boolean hasAggregateNode = false; + // true if SubQuery rel tree has Over node, else false. + boolean hasOverNode = false; + + CorelMapBuilder(int maxCnfNodeCount) { + this.maxCnfNodeCount = maxCnfNodeCount; + } + + final SortedMap mapCorToCorRel = new TreeMap<>(); + final com.google.common.collect.SortedSetMultimap mapRefRelToCorRef = + com.google.common.collect.Multimaps + .newSortedSetMultimap(new HashMap>(), + new com.google.common.base.Supplier>() { + + public TreeSet get() { + Bug.upgrade("use MultimapBuilder when we're on Guava-16"); + return com.google.common.collect.Sets.newTreeSet(); + } + }); + final Map mapFieldAccessToCorVar = new HashMap<>(); + final Map> mapSubQueryNodeToCorSet = new HashMap<>(); + + int corrIdGenerator = 0; + final Deque corNodeStack = new ArrayDeque<>(); + + /** Creates a CorelMap by iterating over a {@link RelNode} tree. */ + CorelMap build(RelNode... rels) { + for (RelNode rel : rels) { + stripHep(rel).accept(this); + } + return CorelMap.of(mapRefRelToCorRef, mapCorToCorRel, mapSubQueryNodeToCorSet); + } + + @Override protected RelNode visitChild(RelNode parent, int i, RelNode input) { + return super.visitChild(parent, i, stripHep(input)); + } + + @Override public RelNode visit(LogicalCorrelate correlate) { + // TODO does not allow correlation condition in its inputs now + // If correlation conditions in correlate inputs reference to correlate outputs + // variable, + // that should not be supported, e.g. + // SELECT * FROM outer_table l WHERE l.c IN ( + // SELECT f1 FROM ( + // SELECT * FROM inner_table r WHERE r.d IN (SELECT x.i FROM x WHERE x.j = l.b)) t, + // LATERAL TABLE(table_func(t.f)) AS T(f1) + //)) + // other cases should be supported, e.g. + // SELECT * FROM outer_table l WHERE l.c IN ( + // SELECT f1 FROM ( + // SELECT * FROM inner_table r WHERE r.d IN (SELECT x.i FROM x WHERE x.j = r.e)) t, + // LATERAL TABLE(table_func(t.f)) AS T(f1) + //)) + checkCorConditionOfInput(correlate.getLeft()); + checkCorConditionOfInput(correlate.getRight()); + + visitChild(correlate, 0, correlate.getLeft()); + visitChild(correlate, 1, correlate.getRight()); + return correlate; + } + + @Override public RelNode visit(LogicalJoin join) { + switch (join.getJoinType()) { + case LEFT: + checkCorConditionOfInput(join.getRight()); + break; + case RIGHT: + checkCorConditionOfInput(join.getLeft()); + break; + case FULL: + checkCorConditionOfInput(join.getLeft()); + checkCorConditionOfInput(join.getRight()); + break; + default: + break; + } + + final boolean hasSubQuery = RexUtil.SubQueryFinder.find(join.getCondition()) != null; + try { + if (!corNodeStack.isEmpty()) { + mapSubQueryNodeToCorSet.put(join, corNodeStack.peek().getVariablesSet()); + } + if (hasSubQuery) { + corNodeStack.push(join); + } + checkCorCondition(join); + join.getCondition().accept(rexVisitor(join)); + } finally { + if (hasSubQuery) { + corNodeStack.pop(); + } + } + visitChild(join, 0, join.getLeft()); + visitChild(join, 1, join.getRight()); + return join; + } + + @Override public RelNode visit(LogicalFilter filter) { + final boolean hasSubQuery = RexUtil.SubQueryFinder.find(filter.getCondition()) != null; + try { + if (!corNodeStack.isEmpty()) { + mapSubQueryNodeToCorSet.put(filter, corNodeStack.peek().getVariablesSet()); + } + if (hasSubQuery) { + corNodeStack.push(filter); + } + checkCorCondition(filter); + filter.getCondition().accept(rexVisitor(filter)); + for (CorrelationId correlationId : filter.getVariablesSet()) { + mapCorToCorRel.put(correlationId, filter); + } + } finally { + if (hasSubQuery) { + corNodeStack.pop(); + } + } + return super.visit(filter); + } + + @Override public RelNode visit(LogicalProject project) { + hasOverNode = RexOver.containsOver(project.getProjects(), null); + final boolean hasSubQuery = RexUtil.SubQueryFinder.find(project.getProjects()) != null; + try { + if (!corNodeStack.isEmpty()) { + mapSubQueryNodeToCorSet.put(project, corNodeStack.peek().getVariablesSet()); + } + if (hasSubQuery) { + corNodeStack.push(project); + } + checkCorCondition(project); + for (RexNode node : project.getProjects()) { + node.accept(rexVisitor(project)); + } + } finally { + if (hasSubQuery) { + corNodeStack.pop(); + } + } + return super.visit(project); + } + + @Override public RelNode visit(LogicalAggregate aggregate) { + hasAggregateNode = true; + return super.visit(aggregate); + } + + @Override public RelNode visit(LogicalUnion union) { + checkCorConditionOfSetOpInputs(union); + return super.visit(union); + } + + @Override public RelNode visit(LogicalMinus minus) { + checkCorConditionOfSetOpInputs(minus); + return super.visit(minus); + } + + @Override public RelNode visit(LogicalIntersect intersect) { + checkCorConditionOfSetOpInputs(intersect); + return super.visit(intersect); + } + + /** + * check whether the predicate on filter has unsupported correlation condition. e.g. SELECT + * * FROM l WHERE a IN (SELECT c FROM r WHERE l.b = r.d OR r.d > 10) + */ + private void checkCorCondition(final LogicalFilter filter) { + if (mapSubQueryNodeToCorSet.containsKey(filter) && !hasUnsupportedCorCondition) { + final List corConditions = new ArrayList<>(); + final List unsupportedCorConditions = new ArrayList<>(); + analyzeCorConditions( + mapSubQueryNodeToCorSet.get(filter), + filter.getCondition(), + filter.getCluster().getRexBuilder(), + maxCnfNodeCount, + corConditions, + new ArrayList<>(), + unsupportedCorConditions); + if (!unsupportedCorConditions.isEmpty()) { + hasUnsupportedCorCondition = true; + } else if (!corConditions.isEmpty()) { + boolean hasNonEquals = false; + for (RexNode node : corConditions) { + if (node instanceof RexCall + && ((RexCall) node).getOperator() != SqlStdOperatorTable.EQUALS) { + hasNonEquals = true; + break; + } + } + // agg or over with non-equality correlation condition is unsupported, e.g. + // SELECT * FROM l WHERE b IN (SELECT MIN(e) FROM r WHERE l.c > r.f) + // SELECT * FROM l WHERE b IN (SELECT MIN(e) OVER() FROM r WHERE l.c > r.f) + hasUnsupportedCorCondition = hasNonEquals && (hasAggregateNode || hasOverNode); + } + } + } + + /** + * check whether the predicate on join has unsupported correlation condition. e.g. SELECT * + * FROM l WHERE a IN (SELECT c FROM r WHERE l.b IN (SELECT e FROM t)) + */ + private void checkCorCondition(final LogicalJoin join) { + if (!hasUnsupportedCorCondition) { + join.getCondition().accept(new RexVisitorImpl(true) { + + @Override public Void visitCorrelVariable(RexCorrelVariable correlVariable) { + hasUnsupportedCorCondition = true; + return super.visitCorrelVariable(correlVariable); + } + }); + } + } + + /** + * check whether the project has correlation expressions. e.g. SELECT * FROM l WHERE a IN + * (SELECT l.b FROM r) + */ + private void checkCorCondition(final LogicalProject project) { + if (!hasUnsupportedCorCondition) { + for (RexNode node : project.getProjects()) { + node.accept(new RexVisitorImpl(true) { + + @Override public Void visitCorrelVariable(RexCorrelVariable correlVariable) { + hasUnsupportedCorCondition = true; + return super.visitCorrelVariable(correlVariable); + } + }); + } + } + } + + /** + * check whether a node has some input which have correlation condition. e.g. SELECT * FROM + * l WHERE EXISTS (SELECT * FROM r LEFT JOIN (SELECT * FROM t WHERE t.j=l.b) t1 ON r.f=t1.k) + * the above sql can not be converted to semi-join plan, because the right input of + * Left-Join has the correlation condition(t.j=l.b). + */ + private void checkCorConditionOfInput(final RelNode input) { + final RelShuttleImpl shuttle = new RelShuttleImpl() { + + final RexVisitor visitor = new RexVisitorImpl(true) { + + @Override public Void visitCorrelVariable(RexCorrelVariable correlVariable) { + hasUnsupportedCorCondition = true; + return super.visitCorrelVariable(correlVariable); + } + }; + + @Override public RelNode visit(LogicalFilter filter) { + filter.getCondition().accept(visitor); + return super.visit(filter); + } + + @Override public RelNode visit(LogicalProject project) { + for (RexNode rex : project.getProjects()) { + rex.accept(visitor); + } + return super.visit(project); + } + + @Override public RelNode visit(LogicalJoin join) { + join.getCondition().accept(visitor); + return super.visit(join); + } + }; + input.accept(shuttle); + } + + /** + * check whether a SetOp has some children node which have correlation condition. e.g. + * SELECT a FROM l WHERE b IN (SELECT r1.e FROM r1 WHERE l.a = r1.d UNION SELECT r2.i FROM + * r2) + */ + private void checkCorConditionOfSetOpInputs(SetOp setOp) { + for (RelNode child : setOp.getInputs()) { + checkCorConditionOfInput(child); + } + } + + private RexVisitorImpl rexVisitor(final RelNode rel) { + return new RexVisitorImpl(true) { + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + hasAggregateNode = false; // reset to default value + hasOverNode = false; // reset to default value + subQuery.rel.accept(CorelMapBuilder.this); + return super.visitSubQuery(subQuery); + } + + @Override public Void visitFieldAccess(RexFieldAccess fieldAccess) { + final RexNode ref = fieldAccess.getReferenceExpr(); + if (ref instanceof RexCorrelVariable) { + final RexCorrelVariable var = (RexCorrelVariable) ref; + // check the scope of correlation id + // we do not support nested correlation variables in SubQuery, such as: + // select * from t1 where exists (select * from t2 where t1.a = t2.c and + // t2.d in (select t3.d from t3 where t1.b = t3.e) + if (!hasUnsupportedCorCondition) { + hasUnsupportedCorCondition = !mapSubQueryNodeToCorSet.containsKey(rel); + } + if (!hasNestedCorScope && mapSubQueryNodeToCorSet.containsKey(rel)) { + hasNestedCorScope = !mapSubQueryNodeToCorSet.get(rel).contains(var.id); + } + + if (mapFieldAccessToCorVar.containsKey(fieldAccess)) { + // for cases where different Rel nodes are referring to + // same correlation var (e.g. in case of NOT IN) + // avoid generating another correlation var + // and record the 'rel' is using the same correlation + mapRefRelToCorRef.put(rel, mapFieldAccessToCorVar.get(fieldAccess)); + } else { + final CorRef correlation = + new CorRef(var.id, fieldAccess.getField().getIndex(), corrIdGenerator++); + mapFieldAccessToCorVar.put(fieldAccess, correlation); + mapRefRelToCorRef.put(rel, correlation); + } + } + return super.visitFieldAccess(fieldAccess); + } + }; + } + } + + /** + * A unique reference to a correlation field. + * + *

For instance, if a RelNode references emp.name multiple times, it would result in multiple + * {@code CorRef} objects that differ just in {@link CorRef#uniqueKey}. + */ + private static class CorRef implements Comparable { + + final int uniqueKey; + final CorrelationId corr; + final int field; + + CorRef(CorrelationId corr, int field, int uniqueKey) { + this.corr = corr; + this.field = field; + this.uniqueKey = uniqueKey; + } + + @Override public String toString() { + return corr.getName() + '.' + field; + } + + @Override public int hashCode() { + return Objects.hash(uniqueKey, corr, field); + } + + @Override public boolean equals(Object o) { + return this == o + || o instanceof CorRef + && uniqueKey == ((CorRef) o).uniqueKey + && corr == ((CorRef) o).corr + && field == ((CorRef) o).field; + } + + public int compareTo(CorRef o) { + int c = corr.compareTo(o.corr); + if (c != 0) { + return c; + } + c = Integer.compare(field, o.field); + if (c != 0) { + return c; + } + return Integer.compare(uniqueKey, o.uniqueKey); + } + } + + /** + * A map of the locations of correlation variables in a tree of {@link RelNode}s. + * + *

It is used to drive the decorrelation process. Treat it as immutable; rebuild if you + * modify the tree. + * + *

There are three maps: + * + *

    + *
  1. {@link #mapRefRelToCorRef} maps a {@link RelNode} to the correlated variables it + * references; + *
  2. {@link #mapCorToCorRel} maps a correlated variable to the {@link RelNode} providing it; + *
  3. {@link #mapSubQueryNodeToCorSet} maps a {@link RelNode} to the correlated variables it + * has; + *
+ */ + private static class CorelMap { + + private final com.google.common.collect.Multimap mapRefRelToCorRef; + private final SortedMap mapCorToCorRel; + private final Map> mapSubQueryNodeToCorSet; + + // TODO: create immutable copies of all maps + private CorelMap( + com.google.common.collect.Multimap mapRefRelToCorRef, + SortedMap mapCorToCorRel, + Map> mapSubQueryNodeToCorSet) { + this.mapRefRelToCorRef = mapRefRelToCorRef; + this.mapCorToCorRel = mapCorToCorRel; + this.mapSubQueryNodeToCorSet = + ImmutableMap.copyOf(mapSubQueryNodeToCorSet); + } + + @Override public String toString() { + return "mapRefRelToCorRef=" + + mapRefRelToCorRef + + "\nmapCorToCorRel=" + + mapCorToCorRel + + "\nmapSubQueryNodeToCorSet=" + + mapSubQueryNodeToCorSet + + "\n"; + } + + @Override public boolean equals(Object obj) { + return obj == this + || obj instanceof CorelMap + && mapRefRelToCorRef.equals(((CorelMap) obj).mapRefRelToCorRef) + && mapCorToCorRel.equals(((CorelMap) obj).mapCorToCorRel) + && mapSubQueryNodeToCorSet.equals(((CorelMap) obj).mapSubQueryNodeToCorSet); + } + + @Override public int hashCode() { + return Objects.hash(mapRefRelToCorRef, mapCorToCorRel, mapSubQueryNodeToCorSet); + } + + /** Creates a CorelMap with given contents. */ + public static CorelMap of( + com.google.common.collect.SortedSetMultimap mapRefRelToCorVar, + SortedMap mapCorToCorRel, + Map> mapSubQueryNodeToCorSet) { + return new CorelMap(mapRefRelToCorVar, mapCorToCorRel, mapSubQueryNodeToCorSet); + } + + /** + * Returns whether there are any correlating variables in this statement. + * + * @return whether there are any correlating variables + */ + boolean hasCorrelation() { + return !mapCorToCorRel.isEmpty(); + } + } + + /** + * Frame describing the relational expression after decorrelation and where to find the output + * fields and correlation condition. + */ + private static class Frame { + + // the new rel + final RelNode r; + // the condition contains correlation variables + final RexNode c; + // map the oldRel's field indices to newRel's field indices + final com.google.common.collect.ImmutableSortedMap oldToNewOutputs; + + Frame(RelNode oldRel, RelNode newRel, RexNode corCondition, + Map oldToNewOutputs) { + this.r = Objects.requireNonNull(newRel); + this.c = corCondition; + this.oldToNewOutputs = com.google.common.collect.ImmutableSortedMap.copyOf(oldToNewOutputs); + assert allLessThan(this.oldToNewOutputs.keySet(), oldRel.getRowType().getFieldCount(), + Litmus.THROW); + assert allLessThan(this.oldToNewOutputs.values(), r.getRowType().getFieldCount(), + Litmus.THROW); + } + + List getCorInputRefIndices() { + final List inputRefIndices; + if (c != null) { + inputRefIndices = RelOptUtil.InputFinder.bits(c).toList(); + } else { + inputRefIndices = new ArrayList<>(); + } + return inputRefIndices; + } + + private static boolean allLessThan(Collection integers, int limit, Litmus ret) { + for (int value : integers) { + if (value >= limit) { + return ret.fail("out of range; value: {}, limit: {}", value, limit); + } + } + return ret.succeed(); + } + } + + /** + * Result describing the relational expression after decorrelation and where to find the + * equivalent non-correlated expressions and correlated conditions. + */ + public static class Result { + + private final ImmutableMap> subQueryMap; + static final Result EMPTY = new Result(new HashMap<>()); + + private Result(Map> subQueryMap) { + this.subQueryMap = ImmutableMap.copyOf(subQueryMap); + } + + public Pair getSubQueryEquivalent(RexSubQuery subQuery) { + return subQueryMap.get(subQuery); + } + } +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java index ca5b016ae69..17ff6662dfa 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java @@ -20,6 +20,7 @@ import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelShuttleImpl; import org.apache.calcite.rel.core.Collect; import org.apache.calcite.rel.core.Correlate; import org.apache.calcite.rel.core.CorrelationId; @@ -27,9 +28,13 @@ import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rex.LogicVisitor; +import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexCorrelVariable; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; @@ -37,6 +42,7 @@ import org.apache.calcite.rex.RexShuttle; import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlQuantifyOperator; @@ -45,13 +51,17 @@ import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; import com.google.common.collect.ImmutableList; import org.immutables.value.Value; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Deque; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -866,7 +876,7 @@ private static void matchProject(SubQueryRemoveRule rule, project.getProjects(), e); builder.push(project.getInput()); final int fieldCount = builder.peek().getRowType().getFieldCount(); - final Set variablesSet = + final Set variablesSet = RelOptUtil.getVariablesUsed(e.rel); final RexNode target = rule.apply(e, variablesSet, logic, builder, 1, fieldCount, 0); @@ -877,7 +887,12 @@ private static void matchProject(SubQueryRemoveRule rule, } private static void matchFilter(SubQueryRemoveRule rule, - RelOptRuleCall call) { + RelOptRuleCall call, boolean semiJoinEnabled) { + if (semiJoinEnabled) { + if (rule.tryRewriteToSemiJoin(rule, call)) { + return; + } + } final Filter filter = call.rel(0); final Set filterVariablesSet = filter.getVariablesSet(); final RelBuilder builder = call.builder(); @@ -893,7 +908,7 @@ private static void matchFilter(SubQueryRemoveRule rule, ++count; final RelOptUtil.Logic logic = LogicVisitor.find(RelOptUtil.Logic.TRUE, ImmutableList.of(c), e); - final Set variablesSet = + final Set variablesSet = RelOptUtil.getVariablesUsed(e.rel); // Filter without variables could be handled before this change, we do not want // to break it yet for compatibility reason. @@ -912,6 +927,284 @@ private static void matchFilter(SubQueryRemoveRule rule, call.transformTo(builder.build()); } + private boolean tryRewriteToSemiJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { + Filter filter = call.rel(0); + RexNode condition = filter.getCondition(); + + if (hasUnsupportedSubQuery(condition)) { + // has some unsupported subquery, such as: subquery connected with OR + // select * from t1 where t1.a > 10 or t1.b in (select t2.c from t2) + // TODO supports ExistenceJoin + return false; + } + + Optional subQueryCall = findSubQuery(condition); + if (subQueryCall.isEmpty()) { + // ignore scalar query + return false; + } + + SubQueryDecorrelator.Result decorrelate = SubQueryDecorrelator.decorrelateQuery(filter); + if (decorrelate == null) { + // can't handle the query + return false; + } + + RelBuilder relBuilder = call.builder(); + relBuilder.push(filter.getInput()); // push join left + + Optional newCondition = + rule.handleSubQuery(subQueryCall.get(), condition, relBuilder, decorrelate); + if (newCondition.isPresent()) { + RexNode c = newCondition.get(); + if (hasCorrelatedExpressions(c)) { + // some correlated expressions can not be replaced in this rule, + // so we must keep the VariablesSet for decorrelating later in new filter + // RelBuilder.filter can not create Filter with VariablesSet arg + Filter newFilter = filter.copy(filter.getTraitSet(), relBuilder.build(), c); + relBuilder.push(newFilter); + } else { + // all correlated expressions are replaced, + // so we can create a new filter without any VariablesSet + relBuilder.filter(c); + } + relBuilder.project(fields(relBuilder, filter.getRowType().getFieldCount())); + // the sub query has been replaced with a common node, + // so hints in it should also be resolved with the same logic in SqlToRelConverter + RelNode newNode = relBuilder.build(); + RelNode nodeWithHint = RelOptUtil.propagateRelHints(newNode, false); + // RelNode nodeWithCapitalizedJoinHints = FlinkHints.capitalizeJoinHints(nodeWithHint); + // RelNode finalNode = + // nodeWithCapitalizedJoinHints.accept(new ClearJoinHintWithInvalidPropagationShuttle()); + call.transformTo(newNode); + return true; + } else { + return false; + } + } + + private boolean hasCorrelatedExpressions(RexNode... nodes) { + final RelShuttleImpl relShuttle = new RelShuttleImpl() { + private final RexVisitorImpl corVarFinder = new RexVisitorImpl(true) { + @Override public Void visitCorrelVariable(RexCorrelVariable corVar) { + throw new Util.FoundOne(corVar); + } + }; + + @Override public RelNode visit(LogicalFilter filter) { + filter.getCondition().accept(corVarFinder); + return super.visit(filter); + } + + @Override public RelNode visit(LogicalJoin join) { + join.getCondition().accept(corVarFinder); + return super.visit(join); + } + + @Override public RelNode visit(LogicalProject project) { + for (RexNode node : project.getProjects()) { + node.accept(corVarFinder); + } + return super.visit(project); + } + }; + + final RexVisitorImpl subQueryFinder = new RexVisitorImpl(true) { + @Override public Void visitSubQuery(RexSubQuery subQuery) { + subQuery.rel.accept(relShuttle); + return null; + } + }; + + for (RexNode node : nodes) { + try { + node.accept(subQueryFinder); + } catch (Util.FoundOne e) { + return true; + } + } + return false; + } + + private boolean hasUnsupportedSubQuery(RexNode condition) { + try { + condition.accept(new RexVisitorImpl(true) { + + Deque stack = new ArrayDeque<>(); + + private void checkAndConjunctions(RexCall call) { + if (stack.stream().anyMatch(kind -> kind != SqlKind.AND)) { + throw new Util.FoundOne(call); + } + } + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + if (!isScalarQuery(subQuery)) { + checkAndConjunctions(subQuery); + } + return null; + } + + @Override public Void visitCall(RexCall call) { + switch (call.getKind()) { + case NOT: + if (call.getOperands().get(0) instanceof RexSubQuery) { + // ignore scalar query + if (!isScalarQuery(call.getOperands().get(0))) { + checkAndConjunctions(call); + } + } + break; + default: + stack.push(call.getKind()); + super.visitCall(call); + stack.pop(); + } + return null; + } + }); + return false; + } catch (Util.FoundOne e) { + return true; + } + } + + + private Optional findSubQuery(RexNode node) { + try { + node.accept(new RexVisitorImpl(true) { + + @Override public Void visitSubQuery(RexSubQuery subQuery) { + if (!isScalarQuery(subQuery)) { + throw new Util.FoundOne(subQuery); + } + return null; + } + + @Override public Void visitCall(RexCall call) { + if (call.getKind() == SqlKind.NOT && call.getOperands().get(0) instanceof RexSubQuery) { + if (!isScalarQuery(call.getOperands().get(0))) { + throw new Util.FoundOne(call); + } + } + return super.visitCall(call); + } + }); + return Optional.empty(); + } catch (Util.FoundOne e) { + return Optional.of((RexCall) e.getNode()); + } + } + + private boolean isScalarQuery(RexNode n) { + return n.getKind() == SqlKind.SCALAR_QUERY; + } + + private Optional handleSubQuery( + RexCall subQueryCall, + RexNode condition, + RelBuilder relBuilder, + SubQueryDecorrelator.Result decorrelate) { + RelOptUtil.Logic logic = + LogicVisitor.find(RelOptUtil.Logic.TRUE, ImmutableList.of(condition), subQueryCall); + if (logic != RelOptUtil.Logic.TRUE) { + // this should not happen, none unsupported SubQuery could not reach here + // this is just for double-check + return Optional.empty(); + } + + Optional target = apply(subQueryCall, relBuilder, decorrelate); + if (!target.isPresent()) { + return Optional.empty(); + } + + RexNode newCondition = replaceSubQuery(condition, subQueryCall, target.get()); + Optional nextSubQueryCall = findSubQuery(newCondition); + return nextSubQueryCall.map( + subQuery -> handleSubQuery(subQuery, newCondition, relBuilder, decorrelate)) + .orElse(Optional.of(newCondition)); + } + + private Optional apply(RexCall subQueryCall, RelBuilder relBuilder, + SubQueryDecorrelator.Result decorrelate) { + + RexSubQuery subQuery; + boolean withNot = false; + if (subQueryCall instanceof RexSubQuery) { + subQuery = (RexSubQuery) subQueryCall; + } else if (subQueryCall.getOperands().get(0) instanceof RexSubQuery) { + subQuery = (RexSubQuery) subQueryCall.getOperands().get(0); + withNot = subQueryCall.getKind() == SqlKind.NOT; + } else { + return Optional.empty(); + } + + Pair equivalent = decorrelate.getSubQueryEquivalent(subQuery); + + switch (subQuery.getKind()) { +// case IN: +// return handleInSubQuery(subQuery, withNot, equivalent, relBuilder); + case EXISTS: + return handleExistsSubQuery(subQuery, withNot, equivalent, relBuilder); + default: + return Optional.empty(); + } + } + + private Optional handleExistsSubQuery( + RexSubQuery subQuery, + boolean withNot, + Pair equivalent, + RelBuilder relBuilder) { + RexNode joinCondition; + if (equivalent != null) { + // EXISTS has correlation variables + relBuilder.push(equivalent.getKey()); // push join right + joinCondition = equivalent.getValue(); + } else { + // Implement the logic for EXISTS and NOT EXISTS subqueries + int leftFieldCount = relBuilder.peek().getRowType().getFieldCount(); + relBuilder.push(subQuery.rel); // push join right + relBuilder.project(relBuilder.alias(relBuilder.literal(true), "i")); + relBuilder.aggregate(relBuilder.groupKey(), relBuilder.min("m", relBuilder.field(0))); + relBuilder.project(relBuilder.isNotNull(relBuilder.field(0))); + joinCondition = + new RexInputRef(leftFieldCount, + relBuilder.peek().getRowType().getFieldList().get(0).getType()); + } + + if (withNot) { + relBuilder.join(JoinRelType.ANTI, joinCondition); + } else { + relBuilder.join(JoinRelType.SEMI, joinCondition); + } + return Optional.of(relBuilder.literal(true)); + } + + + private static RexNode replaceSubQuery(RexNode condition, RexCall oldSubQueryCall, + RexNode replacement) { + return condition.accept(new RexShuttle() { + + @Override public RexNode visitSubQuery(RexSubQuery subQuery) { + if (oldSubQueryCall.equals(subQuery)) { + return replacement; + } + return subQuery; + } + + @Override public RexNode visitCall(RexCall call) { + if (call.getKind() == SqlKind.NOT && call.getOperands().get(0) instanceof RexSubQuery) { + if (oldSubQueryCall.equals(call)) { + return replacement; + } + } + return super.visitCall(call); + } + }); + } + + private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { final Join join = call.rel(0); final RelBuilder builder = call.builder(); @@ -1006,6 +1299,7 @@ private static class ReplaceSubQueryShuttle extends RexShuttle { return subQuery.equals(this.subQuery) ? replacement : subQuery; } } + /** Rule configuration. */ @Value.Immutable(singleton = false) public interface Config extends RelRule.Config { @@ -1018,13 +1312,21 @@ public interface Config extends RelRule.Config { .withDescription("SubQueryRemoveRule:Project"); Config FILTER = ImmutableSubQueryRemoveRule.Config.builder() - .withMatchHandler(SubQueryRemoveRule::matchFilter) + .withMatchHandler((rule, call) -> SubQueryRemoveRule.matchFilter(rule, call, false)) .build() .withOperandSupplier(b -> b.operand(Filter.class) .predicate(RexUtil.SubQueryFinder::containsSubQuery).anyInputs()) .withDescription("SubQueryRemoveRule:Filter"); + Config FILTER_TO_SEMI = ImmutableSubQueryRemoveRule.Config.builder() + .withMatchHandler((rule, call) -> SubQueryRemoveRule.matchFilter(rule, call, true)) + .build() + .withOperandSupplier(b -> + b.operand(Filter.class) + .predicate(RexUtil.SubQueryFinder::containsSubQuery).anyInputs()) + .withDescription("SubQueryRemoveRule:Filter to semi"); + Config JOIN = ImmutableSubQueryRemoveRule.Config.builder() .withMatchHandler(SubQueryRemoveRule::matchJoin) .build() diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index d7fb23756ae..dcf3edbb3ac 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -100,6 +100,7 @@ import org.apache.calcite.rel.rules.SortProjectTransposeRule; import org.apache.calcite.rel.rules.SortUnionTransposeRule; import org.apache.calcite.rel.rules.SpatialRules; +import org.apache.calcite.rel.rules.SubQueryRemoveRule; import org.apache.calcite.rel.rules.UnionMergeRule; import org.apache.calcite.rel.rules.ValuesReduceRule; import org.apache.calcite.rel.type.RelDataType; @@ -8708,7 +8709,6 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) { .check(); } - @Disabled("[CALCITE-1045]") @Test void testExpandJoinIn() { final String sql = "select empno\n" + "from sales.emp left join sales.dept\n" @@ -8739,6 +8739,38 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) { sql(sql).withSubQueryRules().withLateDecorrelate(true).check(); } + @Test void testConvertExistsToSemiJoin() { + final String sql = "select * from sales.emp\n" + + "where EXISTS (\n" + + " select * from emp e where emp.deptno = e.deptno)"; + sql(sql).withExpand(false).withRule(SubQueryRemoveRule.Config.FILTER_TO_SEMI.toRule()).check(); + } + + @Test void testConvertNotExistsToAntiJoin() { + final String sql = "select * from sales.emp\n" + + "where NOT EXISTS (\n" + + " select * from emp e where emp.deptno = e.deptno)"; + sql(sql).withExpand(false).withRule(SubQueryRemoveRule.Config.FILTER_TO_SEMI.toRule()).check(); + } + + @Test void testConvertExistsWithUnsupportedCondition() { + final String sql = "select * from sales.emp\n" + + "where emp.empno > 0 OR EXISTS (\n" + + " select * from emp e where emp.deptno = e.deptno)"; + sql(sql).withExpand(false).withRule(SubQueryRemoveRule.Config.FILTER_TO_SEMI.toRule()).check(); + } + + @Test void testConvertExistsWithUnsupportedCorrelate() { + final String sql = "select * from sales.emp\n" + + "where EXISTS (\n" + + " select * from emp e where emp.deptno = e.deptno)\n" + + " and sal > (select avg(e2.sal) from emp e2 where emp.empno = e2.empno)"; + sql(sql).withExpand(false).withRule(SubQueryRemoveRule.Config.FILTER_TO_SEMI.toRule()) + .withLateDecorrelate(true).check(); + } + + // TODO test hasCorrelatedExpressions + /** Test case for * [CALCITE-1511] * AssertionError while decorrelating query with two EXISTS diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index cf7fba849d3..36988ee258d 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -1712,6 +1712,117 @@ LogicalProject(EXPR$0=[CASE(>($3, 0), $4, null:INTEGER)], SALARY=[$0], ENROLL_DA LogicalWindow(window#0=[window(order by [1] range between $2 PRECEDING and $2 FOLLOWING aggs [COUNT($0), SUM($0)])]) LogicalProject(EXPR$2=[$2], EXPR$3=[$3], $2=[*(365, 86400000:INTERVAL DAY)]) LogicalValues(tuples=[[{ 'x', 10, 5200, 2007-08-01 }, { null, null, null, null }]]) +]]> + + + + + + + + + + + + + + + + 0 OR EXISTS ( + select * from emp e where emp.deptno = e.deptno)]]> + + + ($0, 0), EXISTS({ +LogicalFilter(condition=[=($cor0.DEPTNO, $7)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +}))], variablesSet=[[$cor0]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($0, 0), IS NOT NULL($9))]) + LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{0}]) + LogicalProject(i=[true]) + LogicalFilter(condition=[=($cor0.DEPTNO, $7)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + (select avg(e2.sal) from emp e2 where emp.empno = e2.empno)]]> + + + ($5, $SCALAR_QUERY({ +LogicalAggregate(group=[{}], EXPR$0=[AVG($0)]) + LogicalProject(SAL=[$5]) + LogicalFilter(condition=[=($cor0.EMPNO, $0)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +})))], variablesSet=[[$cor0]]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($5, $9)]) + LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}]) + LogicalJoin(condition=[=($7, $9)], joinType=[semi]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(DEPTNO=[$7]) + LogicalFilter(condition=[true]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{}], EXPR$0=[AVG($0)]) + LogicalProject(SAL=[$5]) + LogicalFilter(condition=[=($cor0.EMPNO, $0)]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($5, $10))], joinType=[inner]) + LogicalJoin(condition=[=($7, $9)], joinType=[semi]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(DEPTNO=[$7]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{0}], EXPR$0=[AVG($1)]) + LogicalProject(EMPNO=[$0], SAL=[$5]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) ]]> @@ -1868,6 +1979,33 @@ MultiJoin(joinFilter=[true], isFullOuterJoin=[false], joinTypes=[[RIGHT, INNER]] LogicalTableScan(table=[[CATALOG, SALES, A]]) LogicalTableScan(table=[[CATALOG, SALES, B]]) LogicalTableScan(table=[[CATALOG, SALES, C]]) +]]> + + + + + + + + + + + @@ -4737,15 +4875,15 @@ LogicalProject(DEPTNO=[$7])