Skip to content

Extend filtered aggregation optimizer to support only masked partial aggregation cases #25171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.isMergeAggregationsWithAndWithoutFilter;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.plan.AggregationNode.Step.FINAL;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
Expand Down Expand Up @@ -123,11 +127,13 @@ private static class Context
{
private final Map<VariableReferenceExpression, VariableReferenceExpression> partialResultToMask;
private final Map<VariableReferenceExpression, VariableReferenceExpression> partialOutputMapping;
private final List<VariableReferenceExpression> newAggregationOutput;

/**
 * Creates an empty context: both lookup maps and the list of new
 * aggregation outputs start out with no entries.
 */
public Context()
{
    this.newAggregationOutput = new LinkedList<>();
    this.partialOutputMapping = new HashMap<>();
    this.partialResultToMask = new HashMap<>();
}

public boolean isEmpty()
Expand All @@ -139,6 +145,7 @@ public void clear()
{
partialResultToMask.clear();
partialOutputMapping.clear();
newAggregationOutput.clear();
}

public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialOutputMapping()
Expand All @@ -150,6 +157,11 @@ public Map<VariableReferenceExpression, VariableReferenceExpression> getPartialR
{
return partialResultToMask;
}

/**
 * Returns the list stored in {@code newAggregationOutput}.
 * The live list is handed out (not a copy), so callers can add to it directly.
 */
public List<VariableReferenceExpression> getNewAggregationOutput()
{
    return this.newAggregationOutput;
}
}

private static class Rewriter
Expand Down Expand Up @@ -218,17 +230,60 @@ else if (node.getStep().equals(FINAL)) {
private AggregationNode createPartialAggregationNode(AggregationNode node, PlanNode rewrittenSource, RewriteContext<Context> context)
{
checkState(context.get().isEmpty(), "There should be no partial aggregation left unmerged for a partial aggregation node");

Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsWithoutMaskToOutput = node.getAggregations().entrySet().stream()
.filter(x -> !x.getValue().getMask().isPresent())
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey(), (a, b) -> a));
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey, (a, b) -> a));
Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsToMergeOutput = node.getAggregations().entrySet().stream()
.filter(x -> x.getValue().getMask().isPresent() && aggregationsWithoutMaskToOutput.containsKey(removeFilterAndMask(x.getValue())))
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey()));
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));

ImmutableMap.Builder<AggregationNode.Aggregation, VariableReferenceExpression> partialAggregationToOutputBuilder = ImmutableMap.builder();
partialAggregationToOutputBuilder.putAll(aggregationsToMergeOutput.keySet().stream().collect(toImmutableMap(Function.identity(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x)))));

List<List<AggregationNode.Aggregation>> candidateAggregationsWithMaskNotMatched = node.getAggregations().entrySet().stream().map(Map.Entry::getValue)
.filter(x -> x.getMask().isPresent() && !aggregationsToMergeOutput.containsKey(x))
.collect(Collectors.groupingBy(AggregationNodeUtils::removeFilterAndMask)).values()
.stream().filter(x -> x.size() > 1).collect(toImmutableList());

Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsWithMaskToMerge = node.getAggregations().entrySet().stream()
.filter(x -> aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue())))
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newMaskAssignmentsBuilder = ImmutableMap.builder();
ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAddedBuilder = ImmutableMap.builder();
List<AggregationNode.Aggregation> newAggregationAdded = candidateAggregationsWithMaskNotMatched.stream()
.map(aggregations ->
{
List<VariableReferenceExpression> maskVariables = aggregations.stream().map(x -> x.getMask().get()).collect(toImmutableList());
RowExpression orMaskVariables = or(maskVariables);
VariableReferenceExpression newMaskVariable = variableAllocator.newVariable(orMaskVariables);
newMaskAssignmentsBuilder.put(newMaskVariable, orMaskVariables);
AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation(
aggregations.get(0).getCall(),
Optional.empty(),
aggregations.get(0).getOrderBy(),
aggregations.get(0).isDistinct(),
Optional.of(newMaskVariable));
VariableReferenceExpression newAggregationVariable = variableAllocator.newVariable(newAggregation.getCall());
aggregationsAddedBuilder.put(newAggregationVariable, newAggregation);
aggregations.forEach(x -> partialAggregationToOutputBuilder.put(x, newAggregationVariable));
return newAggregation;
})
.collect(toImmutableList());
Map<VariableReferenceExpression, RowExpression> newMaskAssignments = newMaskAssignmentsBuilder.build();
Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsAdded = aggregationsAddedBuilder.build();
Map<AggregationNode.Aggregation, VariableReferenceExpression> partialAggregationToOutput = partialAggregationToOutputBuilder.build();

Map<AggregationNode.Aggregation, VariableReferenceExpression> aggregationsToMergeOutputCombined =
node.getAggregations().entrySet().stream()
.filter(x -> x.getValue().getMask().isPresent() && aggregationsToMergeOutput.containsKey(x.getValue()) || candidateAggregationsWithMaskNotMatched.stream().anyMatch(aggregations -> aggregations.contains(x.getValue())))
.collect(toImmutableMap(Map.Entry::getValue, Map.Entry::getKey));

context.get().getPartialResultToMask().putAll(aggregationsToMergeOutput.entrySet().stream()
.collect(toImmutableMap(x -> x.getValue(), x -> x.getKey().getMask().get())));
context.get().getPartialOutputMapping().putAll(aggregationsToMergeOutput.entrySet().stream()
.collect(toImmutableMap(x -> x.getValue(), x -> aggregationsWithoutMaskToOutput.get(removeFilterAndMask(x.getKey())))));
context.get().getNewAggregationOutput().addAll(aggregationsAdded.keySet());
context.get().getPartialResultToMask().putAll(aggregationsWithMaskToMerge.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getValue, x -> x.getKey().getMask().get())));
context.get().getPartialOutputMapping().putAll(aggregationsWithMaskToMerge.entrySet().stream()
.collect(toImmutableMap(Map.Entry::getValue, x -> partialAggregationToOutput.get(x.getKey()))));

Set<VariableReferenceExpression> maskVariables = new HashSet<>(context.get().getPartialResultToMask().values());
if (maskVariables.isEmpty()) {
Expand All @@ -242,14 +297,21 @@ private AggregationNode createPartialAggregationNode(AggregationNode node, PlanN
AggregationNode.GroupingSetDescriptor partialGroupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(
groupingVariables.build(), groupingSetDescriptor.getGroupingSetCount(), groupingSetDescriptor.getGlobalGroupingSets());

Set<VariableReferenceExpression> partialResultToMerge = new HashSet<>(aggregationsToMergeOutput.values());
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = node.getAggregations().entrySet().stream()
Set<VariableReferenceExpression> partialResultToMerge = new HashSet<>(aggregationsToMergeOutputCombined.values());
Map<VariableReferenceExpression, AggregationNode.Aggregation> aggregationsRemained = node.getAggregations().entrySet().stream()
.filter(x -> !partialResultToMerge.contains(x.getKey())).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
Map<VariableReferenceExpression, AggregationNode.Aggregation> newAggregations = ImmutableMap.<VariableReferenceExpression, AggregationNode.Aggregation>builder()
.putAll(aggregationsRemained).putAll(aggregationsAdded).build();

PlanNode newChild = rewrittenSource;
if (!newMaskAssignments.isEmpty()) {
newChild = addProjections(newChild, planNodeIdAllocator, newMaskAssignments);
}

return new AggregationNode(
node.getSourceLocation(),
node.getId(),
rewrittenSource,
newChild,
newAggregations,
partialGroupingSetDescriptor,
node.getPreGroupedVariables(),
Expand All @@ -265,7 +327,7 @@ private AggregationNode createFinalAggregationNode(AggregationNode node, PlanNod
return (AggregationNode) node.replaceChildren(ImmutableList.of(rewrittenSource));
}
List<VariableReferenceExpression> intermediateVariables = node.getAggregations().values().stream()
.map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(Collectors.toList());
.map(x -> (VariableReferenceExpression) x.getArguments().get(0)).collect(toImmutableList());
checkState(intermediateVariables.containsAll(context.get().partialResultToMask.keySet()));

ImmutableList.Builder<RowExpression> projectionsFromPartialAgg = ImmutableList.builder();
Expand Down Expand Up @@ -331,6 +393,7 @@ public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
assignments.putAll(excludeMergedAssignments);
assignments.putAll(identityAssignments(context.get().getPartialResultToMask().values()));
assignments.putAll(identityAssignments(context.get().getNewAggregationOutput()));
return new ProjectNode(
node.getSourceLocation(),
node.getId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,37 @@ public void testOptimizationApplied()
false);
}

/**
 * Plan test for a query in which every aggregation carries a filter: two
 * {@code sum(quantity)} calls with different {@code orderkey} predicates and no
 * unfiltered counterpart.
 * <p>
 * The expected plan contains a single PARTIAL sum masked by the OR of both
 * filter expressions, with both filter expressions added to the grouping set.
 * A projection then derives one masked input per FINAL sum via
 * {@code IF(<filter>, partialSum, null)}.
 */
@Test
public void testOptimizationAppliedAllHasMask()
{
assertPlan("SELECT partkey, sum(quantity) filter (where orderkey > 10), sum(quantity) filter (where orderkey > 0) from lineitem group by partkey",
enableOptimization(),
anyTree(
// FINAL step: one sum per original filtered aggregation, no masks left.
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum")),
Optional.of("maskFinalSum2"), functionCall("sum", ImmutableList.of("maskPartialSum2"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
// Each final input is the shared partial sum, nulled out where its own filter is false.
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
"maskPartialSum2", expression("IF(expr2, partialSum, null)")),
anyTree(
// PARTIAL step: a single sum masked by expr_or, grouped additionally by both filter expressions.
aggregation(
singleGroupingSet("partkey", "expr", "expr2"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
// The combined mask is computed by a projection placed directly under the partial aggregation.
project(
ImmutableMap.of("expr_or", expression("expr or expr2")),
project(
ImmutableMap.of("expr", expression("orderkey > 0"), "expr2", expression("orderkey >10")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity"))))))))),
false);
}

@Test
public void testOptimizationDisabled()
{
Expand Down Expand Up @@ -188,6 +219,57 @@ public void testAggregationsMultipleLevel()
false);
}

/**
 * Plan test for two nested aggregation levels where the merged aggregations
 * all carry filters.
 * <p>
 * The inner query computes two filtered {@code sum(quantity)} values. The outer
 * query computes two filtered {@code avg} values over the same inner column,
 * plus one unfiltered {@code avg} over the other inner column.
 * <p>
 * At each level the expected plan has one PARTIAL aggregation masked by the OR
 * of the two filter expressions, both filter expressions in the grouping set,
 * and an {@code IF(<filter>, partial, null)} projection feeding the FINAL step.
 * The unfiltered outer {@code avg} appears as its own partial aggregation with
 * no mask.
 */
@Test
public void testAggregationsMultipleLevelAllAggWithMask()
{
assertPlan("select partkey, avg(sum) filter (where suppkey > 10), avg(sum) filter (where suppkey > 0), avg(filtersum) from (select partkey, suppkey, sum(quantity) filter (where orderkey > 10) sum, sum(quantity) filter (where orderkey > 0) filtersum from lineitem group by partkey, suppkey) t group by partkey",
enableOptimization(),
anyTree(
// Outer FINAL step: three avg outputs, none masked.
aggregation(
singleGroupingSet("partkey"),
ImmutableMap.of(Optional.of("finalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg_g10")), Optional.of("maskFinalAvg"), functionCall("avg", ImmutableList.of("maskPartialAvg")),
Optional.of("finalFilterAvg"), functionCall("avg", ImmutableList.of("partialFilterAvg"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
// Both filtered averages are derived from the same partialAvg, each gated by its own filter.
project(
ImmutableMap.of("maskPartialAvg", expression("IF(expr_2, partialAvg, null)"),
"maskPartialAvg_g10", expression("IF(expr_2_g10, partialAvg, null)")),
anyTree(
// Outer PARTIAL step: partialAvg is masked by expr_2_or; partialFilterAvg has no mask entry.
aggregation(
singleGroupingSet("partkey", "expr_2", "expr_2_g10"),
ImmutableMap.of(Optional.of("partialAvg"), functionCall("avg", ImmutableList.of("finalSum_g10")), Optional.of("partialFilterAvg"), functionCall("avg", ImmutableList.of("maskFinalSum"))),
ImmutableMap.of(new Symbol("partialAvg"), new Symbol("expr_2_or")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr_2_or", expression("expr_2 or expr_2_g10")),
project(
ImmutableMap.of("expr_2", expression("suppkey > 0"), "expr_2_g10", expression("suppkey > 10")),
// Inner FINAL step: two sums, fed by the masked projections below.
aggregation(
singleGroupingSet("partkey", "suppkey"),
ImmutableMap.of(Optional.of("finalSum_g10"), functionCall("sum", ImmutableList.of("maskPartialSum_g10")), Optional.of("maskFinalSum"), functionCall("sum", ImmutableList.of("maskPartialSum"))),
ImmutableMap.of(),
Optional.empty(),
AggregationNode.Step.FINAL,
project(
ImmutableMap.of("maskPartialSum", expression("IF(expr, partialSum, null)"),
"maskPartialSum_g10", expression("IF(expr_g10, partialSum, null)")),
anyTree(
// Inner PARTIAL step: a single sum masked by expr_or, grouped additionally by both filter expressions.
aggregation(
singleGroupingSet("partkey", "suppkey", "expr", "expr_g10"),
ImmutableMap.of(Optional.of("partialSum"), functionCall("sum", ImmutableList.of("quantity"))),
ImmutableMap.of(new Symbol("partialSum"), new Symbol("expr_or")),
Optional.empty(),
AggregationNode.Step.PARTIAL,
project(
ImmutableMap.of("expr_or", expression("expr or expr_g10")),
project(
ImmutableMap.of("expr", expression("orderkey > 0"), "expr_g10", expression("orderkey > 10")),
tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "partkey", "partkey", "quantity", "quantity", "suppkey", "suppkey"))))))))))))))),
false);
}

@Test
public void testGlobalOptimization()
{
Expand Down
Loading
Loading