Skip to content

Commit 9bf83c7

Browse files
committed
Add optimizer to convert min_by/max_by to row number function
1 parent f419d2f commit 9bf83c7

File tree

6 files changed

+393
-0
lines changed

6 files changed

+393
-0
lines changed

presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ public final class SystemSessionProperties
330330
public static final String SINGLE_NODE_EXECUTION_ENABLED = "single_node_execution_enabled";
331331
public static final String EXPRESSION_OPTIMIZER_NAME = "expression_optimizer_name";
332332
public static final String ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID = "add_exchange_below_partial_aggregation_over_group_id";
333+
public static final String REWRITE_MIN_MAX_BY_TO_TOP_N = "rewrite_min_max_by_to_top_n";
333334

334335
// TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future.
335336
public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all";
@@ -1874,6 +1875,11 @@ public SystemSessionProperties(
18741875
"Enable single node execution",
18751876
featuresConfig.isSingleNodeExecutionEnabled(),
18761877
false),
1878+
booleanProperty(
1879+
REWRITE_MIN_MAX_BY_TO_TOP_N,
1880+
"rewrite min_by/max_by to top n",
1881+
true,
1882+
false),
18771883
booleanProperty(NATIVE_EXECUTION_SCALE_WRITER_THREADS_ENABLED,
18781884
"Enable automatic scaling of writer threads",
18791885
featuresConfig.isNativeExecutionScaleWritersThreadsEnabled(),
@@ -2342,6 +2348,11 @@ public static boolean isSingleNodeExecutionEnabled(Session session)
23422348
return session.getSystemProperty(SINGLE_NODE_EXECUTION_ENABLED, Boolean.class);
23432349
}
23442350

2351+
public static boolean isRewriteMinMaxByToTopNEnabled(Session session)
2352+
{
2353+
return session.getSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, Boolean.class);
2354+
}
2355+
23452356
public static boolean isPushAggregationThroughJoin(Session session)
23462357
{
23472358
return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, Boolean.class);

presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort;
6060
import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithTopN;
6161
import com.facebook.presto.sql.planner.iterative.rule.MergeLimits;
62+
import com.facebook.presto.sql.planner.iterative.rule.MinMaxByToWindowFunction;
6263
import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct;
6364
import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout;
6465
import com.facebook.presto.sql.planner.iterative.rule.PlanRemoteProjections;
@@ -655,6 +656,12 @@ public PlanOptimizers(
655656
ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionAndTypeManager()))),
656657
new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown
657658
new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits
659+
new IterativeOptimizer(
660+
metadata,
661+
ruleStats,
662+
statsCalculator,
663+
estimatedExchangesCostCalculator,
664+
ImmutableSet.of(new MinMaxByToWindowFunction(metadata.getFunctionAndTypeManager()))),
658665
prefilterForLimitingAggregation,
659666
new IterativeOptimizer(
660667
metadata,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.sql.planner.iterative.rule;
15+
16+
import com.facebook.presto.Session;
17+
import com.facebook.presto.common.block.SortOrder;
18+
import com.facebook.presto.common.type.MapType;
19+
import com.facebook.presto.matching.Captures;
20+
import com.facebook.presto.matching.Pattern;
21+
import com.facebook.presto.metadata.FunctionAndTypeManager;
22+
import com.facebook.presto.spi.plan.AggregationNode;
23+
import com.facebook.presto.spi.plan.Assignments;
24+
import com.facebook.presto.spi.plan.DataOrganizationSpecification;
25+
import com.facebook.presto.spi.plan.FilterNode;
26+
import com.facebook.presto.spi.plan.Ordering;
27+
import com.facebook.presto.spi.plan.OrderingScheme;
28+
import com.facebook.presto.spi.plan.ProjectNode;
29+
import com.facebook.presto.spi.relation.ConstantExpression;
30+
import com.facebook.presto.spi.relation.RowExpression;
31+
import com.facebook.presto.spi.relation.VariableReferenceExpression;
32+
import com.facebook.presto.sql.planner.iterative.Rule;
33+
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
34+
import com.facebook.presto.sql.relational.FunctionResolution;
35+
import com.google.common.collect.ImmutableList;
36+
import com.google.common.collect.ImmutableMap;
37+
38+
import java.util.List;
39+
import java.util.Map;
40+
import java.util.Optional;
41+
42+
import static com.facebook.presto.SystemSessionProperties.isRewriteMinMaxByToTopNEnabled;
43+
import static com.facebook.presto.common.function.OperatorType.EQUAL;
44+
import static com.facebook.presto.common.type.BigintType.BIGINT;
45+
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
46+
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
47+
import static com.facebook.presto.sql.relational.Expressions.comparisonExpression;
48+
import static com.google.common.collect.ImmutableMap.toImmutableMap;
49+
50+
public class MinMaxByToWindowFunction
51+
implements Rule<AggregationNode>
52+
{
53+
private static final Pattern<AggregationNode> PATTERN = aggregation().matching(x -> !x.getHashVariable().isPresent() && !x.getGroupingKeys().isEmpty() && x.getGroupingSetCount() == 1 && x.getStep().equals(AggregationNode.Step.SINGLE));
54+
private final FunctionResolution functionResolution;
55+
56+
public MinMaxByToWindowFunction(FunctionAndTypeManager functionAndTypeManager)
57+
{
58+
this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
59+
}
60+
61+
@Override
62+
public boolean isEnabled(Session session)
63+
{
64+
return isRewriteMinMaxByToTopNEnabled(session);
65+
}
66+
67+
@Override
68+
public Pattern<AggregationNode> getPattern()
69+
{
70+
return PATTERN;
71+
}
72+
73+
@Override
74+
public Result apply(AggregationNode node, Captures captures, Context context)
75+
{
76+
Map<VariableReferenceExpression, AggregationNode.Aggregation> maxByAggregations = node.getAggregations().entrySet().stream()
77+
.filter(x -> functionResolution.isMaxByFunction(x.getValue().getFunctionHandle()) && x.getValue().getArguments().get(0).getType() instanceof MapType)
78+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
79+
Map<VariableReferenceExpression, AggregationNode.Aggregation> minByAggregations = node.getAggregations().entrySet().stream()
80+
.filter(x -> functionResolution.isMinByFunction(x.getValue().getFunctionHandle()) && x.getValue().getArguments().get(0).getType() instanceof MapType)
81+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
82+
boolean isMaxByAggregation;
83+
Map<VariableReferenceExpression, AggregationNode.Aggregation> candidateAggregation;
84+
if (maxByAggregations.isEmpty() && !minByAggregations.isEmpty()) {
85+
isMaxByAggregation = false;
86+
candidateAggregation = minByAggregations;
87+
}
88+
else if (!maxByAggregations.isEmpty() && minByAggregations.isEmpty()) {
89+
isMaxByAggregation = true;
90+
candidateAggregation = maxByAggregations;
91+
}
92+
else {
93+
return Result.empty();
94+
}
95+
boolean allMaxOrMinByWithSameField = candidateAggregation.values().stream().map(x -> x.getArguments().get(1)).distinct().count() == 1;
96+
if (!allMaxOrMinByWithSameField) {
97+
return Result.empty();
98+
}
99+
VariableReferenceExpression orderByVariable = (VariableReferenceExpression) candidateAggregation.values().stream().findFirst().get().getArguments().get(1);
100+
Map<VariableReferenceExpression, AggregationNode.Aggregation> remainingAggregations = node.getAggregations().entrySet().stream().filter(x -> !candidateAggregation.containsKey(x.getKey()))
101+
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
102+
boolean remainingEmptyOrMinOrMaxOnOrderBy = remainingAggregations.isEmpty() || (remainingAggregations.size() == 1
103+
&& remainingAggregations.values().stream().allMatch(x -> (isMaxByAggregation ? functionResolution.isMaxFunction(x.getFunctionHandle()) : functionResolution.isMinFunction(x.getFunctionHandle())) && x.getArguments().size() == 1 && x.getArguments().get(0).equals(orderByVariable)));
104+
if (!remainingEmptyOrMinOrMaxOnOrderBy) {
105+
return Result.empty();
106+
}
107+
108+
List<VariableReferenceExpression> partitionKeys = node.getGroupingKeys();
109+
OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(new Ordering(orderByVariable, isMaxByAggregation ? SortOrder.DESC_NULLS_LAST : SortOrder.ASC_NULLS_LAST)));
110+
DataOrganizationSpecification dataOrganizationSpecification = new DataOrganizationSpecification(partitionKeys, Optional.of(orderingScheme));
111+
VariableReferenceExpression rowNumberVariable = context.getVariableAllocator().newVariable("row_number", BIGINT);
112+
TopNRowNumberNode topNRowNumberNode =
113+
new TopNRowNumberNode(node.getSourceLocation(),
114+
context.getIdAllocator().getNextId(),
115+
node.getStatsEquivalentPlanNode(),
116+
node.getSource(),
117+
dataOrganizationSpecification,
118+
rowNumberVariable,
119+
1,
120+
false,
121+
Optional.empty());
122+
RowExpression equal = comparisonExpression(functionResolution, EQUAL, rowNumberVariable, new ConstantExpression(1L, BIGINT));
123+
FilterNode filterNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getStatsEquivalentPlanNode(), topNRowNumberNode, equal);
124+
Map<VariableReferenceExpression, RowExpression> assignments = ImmutableMap.<VariableReferenceExpression, RowExpression>builder()
125+
.putAll(node.getAggregations().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, x -> x.getValue().getArguments().get(0)))).build();
126+
127+
ProjectNode projectNode = new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getStatsEquivalentPlanNode(), filterNode,
128+
Assignments.builder().putAll(assignments).putAll(identityAssignments(node.getGroupingKeys())).build(), ProjectNode.Locality.LOCAL);
129+
return Result.ofPlanNode(projectNode);
130+
}
131+
}

presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,16 @@ public FunctionHandle countFunction(Type valueType)
324324
return functionAndTypeResolver.lookupFunction("count", fromTypes(valueType));
325325
}
326326

327+
public boolean isMaxByFunction(FunctionHandle functionHandle)
328+
{
329+
return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("max_by")));
330+
}
331+
332+
public boolean isMinByFunction(FunctionHandle functionHandle)
333+
{
334+
return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("min_by")));
335+
}
336+
327337
@Override
328338
public boolean isMaxFunction(FunctionHandle functionHandle)
329339
{

0 commit comments

Comments
 (0)