diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index f1e94ae0debd3..4c5627e629ee8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -331,6 +331,7 @@ public final class SystemSessionProperties public static final String EXPRESSION_OPTIMIZER_NAME = "expression_optimizer_name"; public static final String ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID = "add_exchange_below_partial_aggregation_over_group_id"; public static final String QUERY_CLIENT_TIMEOUT = "query_client_timeout"; + public static final String REWRITE_MIN_MAX_BY_TO_TOP_N = "rewrite_min_max_by_to_top_n"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all"; @@ -1875,6 +1876,11 @@ public SystemSessionProperties( "Enable single node execution", featuresConfig.isSingleNodeExecutionEnabled(), false), + booleanProperty( + REWRITE_MIN_MAX_BY_TO_TOP_N, + "rewrite min_by/max_by to top n", + featuresConfig.isRewriteMinMaxByToTopNEnabled(), + false), booleanProperty(NATIVE_EXECUTION_SCALE_WRITER_THREADS_ENABLED, "Enable automatic scaling of writer threads", featuresConfig.isNativeExecutionScaleWritersThreadsEnabled(), @@ -2352,6 +2358,11 @@ public static boolean isSingleNodeExecutionEnabled(Session session) return session.getSystemProperty(SINGLE_NODE_EXECUTION_ENABLED, Boolean.class); } + public static boolean isRewriteMinMaxByToTopNEnabled(Session session) + { + return session.getSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, Boolean.class); + } + public static boolean isPushAggregationThroughJoin(Session session) { return session.getSystemProperty(PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 496e9b31ebaf9..9cd78bd997f4c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -295,6 +295,7 @@ public class FeaturesConfig private int eagerPlanValidationThreadPoolSize = 20; private boolean innerJoinPushdownEnabled; private boolean inEqualityJoinPushdownEnabled; + private boolean rewriteMinMaxByToTopNEnabled; private boolean prestoSparkExecutionEnvironment; private boolean singleNodeExecutionEnabled; @@ -2909,6 +2910,19 @@ public FeaturesConfig setInEqualityJoinPushdownEnabled(boolean inEqualityJoinPus return this; } + public boolean isRewriteMinMaxByToTopNEnabled() + { + return rewriteMinMaxByToTopNEnabled; + } + + @Config("optimizer.rewrite-minBy-maxBy-to-topN-enabled") + @ConfigDescription("Rewrite min_by and max_by to topN") + public FeaturesConfig setRewriteMinMaxByToTopNEnabled(boolean rewriteMinMaxByToTopNEnabled) + { + this.rewriteMinMaxByToTopNEnabled = rewriteMinMaxByToTopNEnabled; + return this; + } + public boolean isInEqualityJoinPushdownEnabled() { return inEqualityJoinPushdownEnabled; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 1da8a7acfe6dd..3dfca1a161312 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -59,6 +59,7 @@ import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithSort; import com.facebook.presto.sql.planner.iterative.rule.MergeLimitWithTopN; import com.facebook.presto.sql.planner.iterative.rule.MergeLimits; +import com.facebook.presto.sql.planner.iterative.rule.MinMaxByToWindowFunction; import com.facebook.presto.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PickTableLayout; import com.facebook.presto.sql.planner.iterative.rule.PlanRemoteProjections; @@ -653,6 +654,12 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new SimplifyCountOverConstant(metadata.getFunctionAndTypeManager()))), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new MinMaxByToWindowFunction(metadata.getFunctionAndTypeManager()))), new LimitPushDown(), // Run LimitPushDown before WindowFilterPushDown new WindowFilterPushDown(metadata), // This must run after PredicatePushDown and LimitPushDown so that it squashes any successive filter nodes and limits prefilterForLimitingAggregation, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java new file mode 100644 index 0000000000000..c8acb0b6a833f --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.common.type.MapType; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isRewriteMinMaxByToTopNEnabled; +import static com.facebook.presto.common.function.OperatorType.EQUAL; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments; +import static com.facebook.presto.sql.planner.plan.Patterns.aggregation; +import static com.facebook.presto.sql.relational.Expressions.comparisonExpression; +import static com.google.common.collect.ImmutableMap.toImmutableMap; + +public class MinMaxByToWindowFunction + implements Rule +{ + private static final Pattern PATTERN = aggregation().matching(x -> !x.getHashVariable().isPresent() && !x.getGroupingKeys().isEmpty() && x.getGroupingSetCount() == 1 && x.getStep().equals(AggregationNode.Step.SINGLE)); + private final FunctionResolution functionResolution; + + public MinMaxByToWindowFunction(FunctionAndTypeManager functionAndTypeManager) + { + this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + } + + @Override + public boolean isEnabled(Session session) + { + return isRewriteMinMaxByToTopNEnabled(session); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode node, Captures captures, Context context) + { + Map maxByAggregations = node.getAggregations().entrySet().stream() + .filter(x -> functionResolution.isMaxByFunction(x.getValue().getFunctionHandle()) && x.getValue().getArguments().get(0).getType() instanceof MapType) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + Map minByAggregations = node.getAggregations().entrySet().stream() + .filter(x -> functionResolution.isMinByFunction(x.getValue().getFunctionHandle()) && x.getValue().getArguments().get(0).getType() instanceof MapType) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + boolean isMaxByAggregation; + Map candidateAggregation; + if (maxByAggregations.isEmpty() && !minByAggregations.isEmpty()) { + isMaxByAggregation = false; + candidateAggregation = minByAggregations; + } + else if (!maxByAggregations.isEmpty() && minByAggregations.isEmpty()) { + isMaxByAggregation = true; + candidateAggregation = maxByAggregations; + } + else { + return Result.empty(); + } + boolean allMaxOrMinByWithSameField = candidateAggregation.values().stream().map(x -> x.getArguments().get(1)).distinct().count() == 1; + if (!allMaxOrMinByWithSameField) { + return Result.empty(); + } + VariableReferenceExpression orderByVariable = (VariableReferenceExpression) candidateAggregation.values().stream().findFirst().get().getArguments().get(1); + Map remainingAggregations = node.getAggregations().entrySet().stream().filter(x -> !candidateAggregation.containsKey(x.getKey())) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + boolean remainingEmptyOrMinOrMaxOnOrderBy = remainingAggregations.isEmpty() || (remainingAggregations.size() == 1 + && remainingAggregations.values().stream().allMatch(x -> (isMaxByAggregation ? functionResolution.isMaxFunction(x.getFunctionHandle()) : functionResolution.isMinFunction(x.getFunctionHandle())) && x.getArguments().size() == 1 && x.getArguments().get(0).equals(orderByVariable))); + if (!remainingEmptyOrMinOrMaxOnOrderBy) { + return Result.empty(); + } + + List partitionKeys = node.getGroupingKeys(); + OrderingScheme orderingScheme = new OrderingScheme(ImmutableList.of(new Ordering(orderByVariable, isMaxByAggregation ? SortOrder.DESC_NULLS_LAST : SortOrder.ASC_NULLS_LAST))); + DataOrganizationSpecification dataOrganizationSpecification = new DataOrganizationSpecification(partitionKeys, Optional.of(orderingScheme)); + VariableReferenceExpression rowNumberVariable = context.getVariableAllocator().newVariable("row_number", BIGINT); + TopNRowNumberNode topNRowNumberNode = + new TopNRowNumberNode(node.getSourceLocation(), + context.getIdAllocator().getNextId(), + node.getStatsEquivalentPlanNode(), + node.getSource(), + dataOrganizationSpecification, + rowNumberVariable, + 1, + false, + Optional.empty()); + RowExpression equal = comparisonExpression(functionResolution, EQUAL, rowNumberVariable, new ConstantExpression(1L, BIGINT)); + FilterNode filterNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getStatsEquivalentPlanNode(), topNRowNumberNode, equal); + Map assignments = ImmutableMap.builder() + .putAll(node.getAggregations().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, x -> x.getValue().getArguments().get(0)))).build(); + + ProjectNode projectNode = new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getStatsEquivalentPlanNode(), filterNode, + Assignments.builder().putAll(assignments).putAll(identityAssignments(node.getGroupingKeys())).build(), ProjectNode.Locality.LOCAL); + return Result.ofPlanNode(projectNode); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java index 860d1e5ee3927..9281d2c363499 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/FunctionResolution.java @@ -324,6 +324,16 @@ public FunctionHandle countFunction(Type valueType) return functionAndTypeResolver.lookupFunction("count", fromTypes(valueType)); } + public boolean isMaxByFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("max_by"))); + } + + public boolean isMinByFunction(FunctionHandle functionHandle) + { + return functionAndTypeResolver.getFunctionMetadata(functionHandle).getName().equals(functionAndTypeResolver.qualifyObjectName(QualifiedName.of("min_by"))); + } + @Override public boolean isMaxFunction(FunctionHandle functionHandle) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index f2f7bb9aa8843..ac534e726be16 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -259,6 +259,7 @@ public void testDefaults() .setAddExchangeBelowPartialAggregationOverGroupId(false) .setInnerJoinPushdownEnabled(false) .setInEqualityJoinPushdownEnabled(false) + .setRewriteMinMaxByToTopNEnabled(false) .setPrestoSparkExecutionEnvironment(false)); } @@ -458,6 +459,7 @@ public void testExplicitPropertyMappings() .put("eager-plan-validation-thread-pool-size", "2") .put("optimizer.inner-join-pushdown-enabled", "true") .put("optimizer.inequality-join-pushdown-enabled", "true") + .put("optimizer.rewrite-minBy-maxBy-to-topN-enabled", "true") .put("presto-spark-execution-environment", "true") .put("single-node-execution-enabled", "true") .put("native-execution-scale-writer-threads-enabled", "true") @@ -669,6 +671,7 @@ public void testExplicitPropertyMappings() .setExcludeInvalidWorkerSessionProperties(true) .setAddExchangeBelowPartialAggregationOverGroupId(true) .setInEqualityJoinPushdownEnabled(true) + .setRewriteMinMaxByToTopNEnabled(true) .setInnerJoinPushdownEnabled(true) .setPrestoSparkExecutionEnvironment(true); assertFullMapping(properties, expected); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMinMaxByToWindowFunction.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMinMaxByToWindowFunction.java new file mode 100644 index 0000000000000..6257756dfed13 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMinMaxByToWindowFunction.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.MapType; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.lang.invoke.MethodHandle; + +import static com.facebook.presto.SystemSessionProperties.REWRITE_MIN_MAX_BY_TO_TOP_N; +import static com.facebook.presto.common.block.MethodHandleUtil.compose; +import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter; +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.topNRowNumber; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle; + +public class TestMinMaxByToWindowFunction + extends BaseRuleTest +{ + private static final MethodHandle KEY_NATIVE_EQUALS = getOperatorMethodHandle(OperatorType.EQUAL, BIGINT, BIGINT); + private static final MethodHandle KEY_BLOCK_EQUALS = compose(KEY_NATIVE_EQUALS, nativeValueGetter(BIGINT), nativeValueGetter(BIGINT)); + private static final MethodHandle KEY_NATIVE_HASH_CODE = getOperatorMethodHandle(OperatorType.HASH_CODE, BIGINT); + private static final MethodHandle KEY_BLOCK_HASH_CODE = compose(KEY_NATIVE_HASH_CODE, nativeValueGetter(BIGINT)); + + @Test + public void testMaxByOnly() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", DESC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMaxAndMaxBy() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("max(ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", DESC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMinByOnly() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("min_by(a, ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", ASC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMinAndMinBy() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("min_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("min(ds)")) + .source( + p.values(ds, a, id))); + }) + .matches( + project( + filter( + topNRowNumber( + topNRowNumber -> topNRowNumber + .specification( + ImmutableList.of("id"), + ImmutableList.of("ds"), + ImmutableMap.of("ds", ASC_NULLS_LAST)) + .partial(false), + values("ds", "a", "id"))))); + } + + @Test + public void testMinAndMaxBy() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", new MapType(BIGINT, BIGINT, KEY_BLOCK_EQUALS, KEY_BLOCK_HASH_CODE)); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .addAggregation(p.variable("expr2"), p.rowExpression("min(ds)")) + .source( + p.values(ds, a, id))); + }).doesNotFire(); + } + + @Test + public void testMaxByOnlyNotOnMap() + { + tester().assertThat(new MinMaxByToWindowFunction(getFunctionManager())) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", VARCHAR); + VariableReferenceExpression ds = p.variable("ds", VARCHAR); + VariableReferenceExpression id = p.variable("id", BIGINT); + return p.aggregation(ap -> ap.singleGroupingSet(id).step(AggregationNode.Step.SINGLE) + .addAggregation(p.variable("expr"), p.rowExpression("max_by(a, ds)")) + .source( + p.values(ds, a, id))); + }).doesNotFire(); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index ad37f43f6a0c4..daad91109eb9b 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -89,6 +89,7 @@ import static com.facebook.presto.SystemSessionProperties.REWRITE_EXPRESSION_WITH_CONSTANT_EXPRESSION; import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_ARRAY_CONTAINS_TO_EQUI_JOIN; import static com.facebook.presto.SystemSessionProperties.REWRITE_LEFT_JOIN_NULL_FILTER_TO_SEMI_JOIN; +import static com.facebook.presto.SystemSessionProperties.REWRITE_MIN_MAX_BY_TO_TOP_N; import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT; import static com.facebook.presto.SystemSessionProperties.USE_DEFAULTS_FOR_CORRELATED_AGGREGATION_PUSHDOWN_THROUGH_OUTER_JOINS; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -8028,6 +8029,70 @@ public void testEvaluateProjectOnValues() "SELECT * FROM (VALUES (60, 29))"); } + @Test + public void testMinMaxByToWindowFunction() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "true") + .build(); + Session disabled = Session.builder(getSession()) + .setSystemProperty(REWRITE_MIN_MAX_BY_TO_TOP_N, "false") + .build(); + @Language("SQL") String sql = "with t as (SELECT * FROM ( VALUES (3, '2025-01-08', MAP(ARRAY[2, 1], ARRAY[0.34, 0.92])), (1, '2025-01-02', MAP(ARRAY[1, 3], ARRAY[0.23, 0.5])), " + + "(7, '2025-01-17', MAP(ARRAY[6, 8], ARRAY[0.60, 0.70])), (2, '2025-01-06', MAP(ARRAY[2, 3, 5, 7], ARRAY[0.75, 0.32, 0.19, 0.46])), " + + "(5, '2025-01-14', MAP(ARRAY[8, 4, 6], ARRAY[0.88, 0.99, 0.00])), (4, '2025-01-12', MAP(ARRAY[7, 3, 2], ARRAY[0.33, 0.44, 0.55])), " + + "(8, '2025-01-20', MAP(ARRAY[1, 7, 6], ARRAY[0.35, 0.45, 0.55])), (6, '2025-01-16', MAP(ARRAY[9, 1, 3], ARRAY[0.30, 0.40, 0.50])), " + + "(2, '2025-01-05', MAP(ARRAY[3, 4], ARRAY[0.98, 0.21])), (1, '2025-01-04', MAP(ARRAY[1, 2], ARRAY[0.45, 0.67])), (7, '2025-01-18', MAP(ARRAY[4, 2, 9], ARRAY[0.80, 0.90, 0.10])), " + + "(3, '2025-01-10', MAP(ARRAY[4, 1, 8, 6], ARRAY[0.85, 0.13, 0.42, 0.91])), (8, '2025-01-19', MAP(ARRAY[3, 5], ARRAY[0.15, 0.25])), " + + "(4, '2025-01-11', MAP(ARRAY[5, 6], ARRAY[0.11, 0.22])), (5, '2025-01-13', MAP(ARRAY[1, 9], ARRAY[0.66, 0.77])), (6, '2025-01-15', MAP(ARRAY[2, 5], ARRAY[0.10, 0.20])) ) " + + "t(id, ds, feature)) select id, max_by(feature, ds), max(ds) from t group by id"; + + MaterializedResult result = computeActual(enabled, "explain(type distributed) " + sql); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("TopNRowNumber"), -1); + + assertQueryWithSameQueryRunner(enabled, sql, disabled); + + sql = "with t as (SELECT * FROM ( VALUES (3, '2025-01-08', MAP(ARRAY[2, 1], ARRAY[0.34, 0.92]), MAP(ARRAY['a', 'b'], ARRAY[0.12, 0.88])), " + + "(1, '2025-01-02', MAP(ARRAY[1, 3], ARRAY[0.23, 0.5]), MAP(ARRAY['x', 'y'], ARRAY[0.45, 0.55])), (7, '2025-01-17', MAP(ARRAY[6, 8], ARRAY[0.60, 0.70]), MAP(ARRAY['m', 'n'], ARRAY[0.21, 0.79])), " + + "(2, '2025-01-06', MAP(ARRAY[2, 3, 5, 7], ARRAY[0.75, 0.32, 0.19, 0.46]), MAP(ARRAY['p', 'q', 'r'], ARRAY[0.11, 0.22, 0.67])), (5, '2025-01-14', MAP(ARRAY[8, 4, 6], ARRAY[0.88, 0.99, 0.00]), MAP(ARRAY['s', 't', 'u'], ARRAY[0.33, 0.44, 0.23])), " + + "(4, '2025-01-12', MAP(ARRAY[7, 3, 2], ARRAY[0.33, 0.44, 0.55]), MAP(ARRAY['v', 'w'], ARRAY[0.66, 0.34])), (8, '2025-01-20', MAP(ARRAY[1, 7, 6], ARRAY[0.35, 0.45, 0.55]), MAP(ARRAY['i', 'j', 'k'], ARRAY[0.78, 0.89, 0.12])), " + + "(6, '2025-01-16', MAP(ARRAY[9, 1, 3], ARRAY[0.30, 0.40, 0.50]), MAP(ARRAY['c', 'd'], ARRAY[0.90, 0.10])), (2, '2025-01-05', MAP(ARRAY[3, 4], ARRAY[0.98, 0.21]), MAP(ARRAY['e', 'f'], ARRAY[0.56, 0.44])), " + + "(1, '2025-01-04', MAP(ARRAY[1, 2], ARRAY[0.45, 0.67]), MAP(ARRAY['g', 'h'], ARRAY[0.23, 0.77])) ) t(id, ds, feature, extra_feature)) " + + "select id, max(ds), max_by(feature, ds), max_by(extra_feature, ds) from t group by id"; + + result = computeActual(enabled, "explain(type distributed) " + sql); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("TopNRowNumber"), -1); + + assertQueryWithSameQueryRunner(enabled, sql, disabled); + + sql = "with t as (SELECT * FROM ( VALUES (3, '2025-01-08', MAP(ARRAY[2, 1], ARRAY[0.34, 0.92])), (1, '2025-01-02', MAP(ARRAY[1, 3], ARRAY[0.23, 0.5])), " + + "(7, '2025-01-17', MAP(ARRAY[6, 8], ARRAY[0.60, 0.70])), (2, '2025-01-06', MAP(ARRAY[2, 3, 5, 7], ARRAY[0.75, 0.32, 0.19, 0.46])), " + + "(5, '2025-01-14', MAP(ARRAY[8, 4, 6], ARRAY[0.88, 0.99, 0.00])), (4, '2025-01-12', MAP(ARRAY[7, 3, 2], ARRAY[0.33, 0.44, 0.55])), " + + "(8, '2025-01-20', MAP(ARRAY[1, 7, 6], ARRAY[0.35, 0.45, 0.55])), (6, '2025-01-16', MAP(ARRAY[9, 1, 3], ARRAY[0.30, 0.40, 0.50])), " + + "(2, '2025-01-05', MAP(ARRAY[3, 4], ARRAY[0.98, 0.21])), (1, '2025-01-04', MAP(ARRAY[1, 2], ARRAY[0.45, 0.67])), (7, '2025-01-18', MAP(ARRAY[4, 2, 9], ARRAY[0.80, 0.90, 0.10])), " + + "(3, '2025-01-10', MAP(ARRAY[4, 1, 8, 6], ARRAY[0.85, 0.13, 0.42, 0.91])), (8, '2025-01-19', MAP(ARRAY[3, 5], ARRAY[0.15, 0.25])), " + + "(4, '2025-01-11', MAP(ARRAY[5, 6], ARRAY[0.11, 0.22])), (5, '2025-01-13', MAP(ARRAY[1, 9], ARRAY[0.66, 0.77])), (6, '2025-01-15', MAP(ARRAY[2, 5], ARRAY[0.10, 0.20])) ) " + + "t(id, ds, feature)) select id, min_by(feature, ds), min(ds) from t group by id"; + + result = computeActual(enabled, "explain(type distributed) " + sql); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("TopNRowNumber"), -1); + + assertQueryWithSameQueryRunner(enabled, sql, disabled); + + sql = "with t as (SELECT * FROM ( VALUES (3, '2025-01-08', MAP(ARRAY[2, 1], ARRAY[0.34, 0.92]), MAP(ARRAY['a', 'b'], ARRAY[0.12, 0.88])), " + + "(1, '2025-01-02', MAP(ARRAY[1, 3], ARRAY[0.23, 0.5]), MAP(ARRAY['x', 'y'], ARRAY[0.45, 0.55])), (7, '2025-01-17', MAP(ARRAY[6, 8], ARRAY[0.60, 0.70]), MAP(ARRAY['m', 'n'], ARRAY[0.21, 0.79])), " + + "(2, '2025-01-06', MAP(ARRAY[2, 3, 5, 7], ARRAY[0.75, 0.32, 0.19, 0.46]), MAP(ARRAY['p', 'q', 'r'], ARRAY[0.11, 0.22, 0.67])), (5, '2025-01-14', MAP(ARRAY[8, 4, 6], ARRAY[0.88, 0.99, 0.00]), MAP(ARRAY['s', 't', 'u'], ARRAY[0.33, 0.44, 0.23])), " + + "(4, '2025-01-12', MAP(ARRAY[7, 3, 2], ARRAY[0.33, 0.44, 0.55]), MAP(ARRAY['v', 'w'], ARRAY[0.66, 0.34])), (8, '2025-01-20', MAP(ARRAY[1, 7, 6], ARRAY[0.35, 0.45, 0.55]), MAP(ARRAY['i', 'j', 'k'], ARRAY[0.78, 0.89, 0.12])), " + + "(6, '2025-01-16', MAP(ARRAY[9, 1, 3], ARRAY[0.30, 0.40, 0.50]), MAP(ARRAY['c', 'd'], ARRAY[0.90, 0.10])), (2, '2025-01-05', MAP(ARRAY[3, 4], ARRAY[0.98, 0.21]), MAP(ARRAY['e', 'f'], ARRAY[0.56, 0.44])), " + + "(1, '2025-01-04', MAP(ARRAY[1, 2], ARRAY[0.45, 0.67]), MAP(ARRAY['g', 'h'], ARRAY[0.23, 0.77])) ) t(id, ds, feature, extra_feature)) " + + "select id, min(ds), min_by(feature, ds), min_by(extra_feature, ds) from t group by id"; + + result = computeActual(enabled, "explain(type distributed) " + sql); + assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("TopNRowNumber"), -1); + + assertQueryWithSameQueryRunner(enabled, sql, disabled); + } + private List getNativeWorkerSessionProperties(List inputRows, String sessionPropertyName) { return inputRows.stream()