Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions pkg/planner/core/issuetest/planner_issue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,31 @@ ORDER BY field1`).Check(testkit.Rows())
tk.MustExec("rollback")
}

// issue-65965-rollup-alias-repeated-grouping-key
// Regression case: the SELECT alias `d` stands for column `a`, and both `a` and `d`
// appear as GROUP BY items under WITH ROLLUP. The expected rows below contain lines
// where `a` is non-NULL while `d` is NULL, so the two items are rolled up separately.
testkit.RunTestUnderCascades(t, func(t *testing.T, tk *testkit.TestKit, cascades, caller string) {
resetTestDB(t, tk)
tk.MustExec("set @@sql_mode = default")
tk.MustExec("create table t1 (a int, b int, c int)")
tk.MustExec("insert into t1 values (1, 2, 3), (4, 5, 6), (7, 8, 9)")

// Columns are a, b, d, sum(c). Each of the three source rows yields three rows
// (full key, d rolled up, b and d rolled up), plus one grand-total row.
expected := testkit.Rows(
"1 2 1 3",
"1 2 <nil> 3",
"1 <nil> <nil> 3",
"4 5 4 6",
"4 5 <nil> 6",
"4 <nil> <nil> 6",
"7 8 7 9",
"7 8 <nil> 9",
"7 <nil> <nil> 9",
"<nil> <nil> <nil> 18",
)
// GROUP BY items written by name; results are sorted before comparison and no warning is expected.
tk.MustQuery("select /* issue:65965 */ a, b, a as d, sum(c) from t1 group by a, b, d with rollup").Sort().Check(expected)
tk.MustQuery("show warnings").Check(testkit.Rows())
// Same query with the GROUP BY items written as SELECT-list ordinals; it must match the same rows.
tk.MustQuery("select /* issue:65965 */ a, b, a as d, sum(c) from t1 group by 1, 2, 3 with rollup").Sort().Check(expected)
tk.MustQuery("show warnings").Check(testkit.Rows())
})

// issue-67802-mutable-user-var-join-cond-should-not-become-inner-side-filter
testkit.RunTestUnderCascades(t, func(t *testing.T, tk *testkit.TestKit, cascades, caller string) {
resetTestDB(t, tk)
Expand Down
108 changes: 86 additions & 22 deletions pkg/planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package core

import (
"bytes"
"context"
"fmt"
"maps"
Expand Down Expand Up @@ -141,24 +142,34 @@ func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) {
return inNode, true
}

func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expression) (base.LogicalPlan, []expression.Expression, error) {
func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expression, gbyItemSourceFieldIndices []int) (base.LogicalPlan, []expression.Expression, error) {
ectx := p.SCtx().GetExprCtx().GetEvalCtx()
b.optFlag |= rule.FlagResolveExpand

// ROLLUP requires an Expand operator to replicate the data: each replica supplies a different grouping layout.
distinctGbyExprs, gbyExprsRefPos := expression.DeduplicateGbyExpression(gbyItems)
// build another projection below.
proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, p.Schema().Len()+len(distinctGbyExprs))}.Init(b.ctx, b.getSelectOffset())
// project: child's output and distinct GbyExprs in advance. (make every group-by item to be a column)
expandGbyExprs := gbyItems
expandGbySourceFieldIndices := gbyItemSourceFieldIndices
gbyExprsRefPos := make([]int, 0, len(gbyItems))
keepGbyItemPositions := needPreserveRollupGbyItemPositions(gbyItems, gbyItemSourceFieldIndices)
if !keepGbyItemPositions {
expandGbyExprs, gbyExprsRefPos = expression.DeduplicateGbyExpression(gbyItems)
expandGbySourceFieldIndices = deriveDeduplicatedGbySourceFieldIndices(gbyExprsRefPos, gbyItemSourceFieldIndices, len(expandGbyExprs))
}
Comment thread
hawkingrei marked this conversation as resolved.
Outdated

// Build another projection below. When a repeated GROUP BY item is introduced by a SELECT
// alias or ordinal, keep every original GROUP BY item position visible to Expand because
// those positions can have different ROLLUP output nullability.
proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, p.Schema().Len()+len(expandGbyExprs))}.Init(b.ctx, b.getSelectOffset())
// project: child's output and GbyExprs in advance. (make every group-by item to be a column)
projSchema := p.Schema().Clone()
names := p.OutputNames()
for _, col := range projSchema.Columns {
proj.Exprs = append(proj.Exprs, col)
}
distinctGbyColNames := make(types.NameSlice, 0, len(distinctGbyExprs))
distinctGbyCols := make([]*expression.Column, 0, len(distinctGbyExprs))
for _, expr := range distinctGbyExprs {
// distinct group expr has been resolved in resolveGby.
distinctGbyColNames := make(types.NameSlice, 0, len(expandGbyExprs))
distinctGbyCols := make([]*expression.Column, 0, len(expandGbyExprs))
for _, expr := range expandGbyExprs {
// group expr has been resolved in resolveGby.
proj.Exprs = append(proj.Exprs, expr)

// add the newly appended names.
Expand All @@ -184,7 +195,15 @@ func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expr
proj.SetSchema(projSchema)
proj.SetChildren(p)
proj.Proj4Expand = true
newGbyItems := expression.RestoreGbyExpression(distinctGbyCols, gbyExprsRefPos)
var newGbyItems []expression.Expression
if keepGbyItemPositions {
newGbyItems = make([]expression.Expression, 0, len(distinctGbyCols))
for _, col := range distinctGbyCols {
newGbyItems = append(newGbyItems, col.Clone())
}
} else {
newGbyItems = expression.RestoreGbyExpression(distinctGbyCols, gbyExprsRefPos)
}

// build expand.
rollupGroupingSets := expression.RollupGroupingSets(newGbyItems)
Expand All @@ -202,7 +221,8 @@ func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expr
DistinctGroupByCol: distinctGbyCols,
DistinctGbyColNames: distinctGbyColNames,
// for resolving grouping function args.
DistinctGbyExprs: distinctGbyExprs,
DistinctGbyExprs: expandGbyExprs,
GbyItemSourceFieldIndices: expandGbySourceFieldIndices,

// fill the gen col names when building level projections.
}.Init(b.ctx, b.getSelectOffset())
Expand Down Expand Up @@ -252,6 +272,36 @@ func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expr
return expand, newGbyItems, nil
}

// needPreserveRollupGbyItemPositions reports whether any GROUP BY item that was
// resolved from a SELECT field has the same canonical hash code as another
// GROUP BY item.
//
// gbyItems and sourceFieldIndices are read in parallel: sourceFieldIndices[i]
// belongs to gbyItems[i], and a negative value marks an item that is skipped
// as a comparison source. Entries of sourceFieldIndices that have no matching
// gbyItems element are ignored instead of indexing out of range.
func needPreserveRollupGbyItemPositions(gbyItems []expression.Expression, sourceFieldIndices []int) bool {
	for i, sourceFieldIndex := range sourceFieldIndices {
		if sourceFieldIndex < 0 || i >= len(gbyItems) {
			continue
		}
		// The hash of item i does not change during the inner scan, so take it once.
		hashCode := gbyItems[i].CanonicalHashCode()
		for j, other := range gbyItems {
			if j != i && bytes.Equal(hashCode, other.CanonicalHashCode()) {
				return true
			}
		}
	}
	return false
}

// deriveDeduplicatedGbySourceFieldIndices projects per-item source field indices
// onto a deduplicated GROUP BY list.
//
// gbyExprsRefPos[idx] is the slot in the deduplicated list that original item idx
// refers to, and sourceFieldIndices[idx] is the value recorded for that item.
// The result has distinctLen entries, all initialised to -1; each slot receives
// the first value other than -1 found among the items that refer to it.
//
// Items without a sourceFieldIndices entry, and ref positions outside
// [0, distinctLen), leave the result untouched rather than panicking.
func deriveDeduplicatedGbySourceFieldIndices(gbyExprsRefPos []int, sourceFieldIndices []int, distinctLen int) []int {
	res := make([]int, distinctLen)
	for i := range res {
		res[i] = -1
	}
	for idx, refPos := range gbyExprsRefPos {
		if idx >= len(sourceFieldIndices) {
			// No source information is available for this item or any later one.
			break
		}
		if refPos < 0 || refPos >= distinctLen {
			continue
		}
		// Keep overwriting while the slot is still -1 so a later bound item can fill it.
		if res[refPos] == -1 {
			res[refPos] = sourceFieldIndices[idx]
		}
	}
	return res
}

func (b *PlanBuilder) buildAggregation(ctx context.Context, p base.LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression,
correlatedAggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, map[int]int, error) {
b.optFlag |= rule.FlagBuildKeyInfo
Expand Down Expand Up @@ -1684,6 +1734,7 @@ func findColFromNaturalUsingJoin(p base.LogicalPlan, col *expression.Column) (na

// resolveGroupingTraverseAction is the traverse action whose Transform method
// rewrites expressions against the Expand operator of the current query block.
type resolveGroupingTraverseAction struct {
// CurrentBlockExpand is the Expand operator that Transform asks to substitute
// expressions with grouping set columns.
CurrentBlockExpand *logicalop.LogicalExpand
// SelectFieldIndex is handed to TrySubstituteExprWithGroupingSetColByFieldIndex
// as the field index; replaceGroupingFunc uses -1 for it.
SelectFieldIndex int
}

func (r resolveGroupingTraverseAction) Transform(expr expression.Expression) (res expression.Expression) {
Expand All @@ -1692,18 +1743,18 @@ func (r resolveGroupingTraverseAction) Transform(expr expression.Expression) (re
// When meeting a column, check whether it is a grouping-set-related column.
// eg: select a, b from t group by a, c with rollup, here a is, while b is not.
// in underlying Expand schema (a,b,c,a',c'), a select list should be resolved to a'.
res, _ = r.CurrentBlockExpand.TrySubstituteExprWithGroupingSetCol(x)
res, _ = r.CurrentBlockExpand.TrySubstituteExprWithGroupingSetColByFieldIndex(x, r.SelectFieldIndex)
case *expression.CorrelatedColumn:
// select 1 in (select t2.a from t group by t2.a, b with rollup) from t2;
// in this case: group by item has correlated column t2.a, and it's select list contains t2.a as well.
res, _ = r.CurrentBlockExpand.TrySubstituteExprWithGroupingSetCol(x)
res, _ = r.CurrentBlockExpand.TrySubstituteExprWithGroupingSetColByFieldIndex(x, r.SelectFieldIndex)
case *expression.Constant:
// constant just keep it real: select 1 from t group by a, b with rollup.
res = x
case *expression.ScalarFunction:
// scalar function just try to resolve itself first, then if not changed, trying resolve its children.
var substituted bool
res, substituted = r.CurrentBlockExpand.TrySubstituteExprWithGroupingSetCol(x)
res, substituted = r.CurrentBlockExpand.TrySubstituteExprWithGroupingSetColByFieldIndex(x, r.SelectFieldIndex)
if !substituted {
// if not changed, try to resolve it children.
// select a+1, grouping(b) from t group by a+1 (projected as c), b with rollup: in this case, a+1 is resolved as c as a whole.
Expand All @@ -1721,13 +1772,17 @@ func (r resolveGroupingTraverseAction) Transform(expr expression.Expression) (re
}

// replaceGroupingFunc rewrites expr through replaceGroupingFuncByFieldIndex
// using -1 as the field index.
func (b *PlanBuilder) replaceGroupingFunc(expr expression.Expression) expression.Expression {
	const noSelectField = -1
	rewritten := b.replaceGroupingFuncByFieldIndex(expr, noSelectField)
	return rewritten
}

// replaceGroupingFuncByFieldIndex traverses expr with a resolveGroupingTraverseAction
// bound to the current block's Expand operator and to fieldIndex.
//
// expr is returned unchanged when the current block has no Expand operator or when
// expr is nil. fieldIndex is stored as the action's SelectFieldIndex;
// replaceGroupingFunc passes -1 for it.
func (b *PlanBuilder) replaceGroupingFuncByFieldIndex(expr expression.Expression, fieldIndex int) expression.Expression {
	// current block doesn't have an expand OP, just return it.
	// expr can be nil when rewrite eliminates a predicate in non-scalar contexts.
	if b.currentBlockExpand == nil || expr == nil {
		return expr
	}
	// curExpand can supply the DistinctGbyExprs and gid col.
	// Declare the action exactly once, carrying the field index along with the Expand.
	traverseAction := resolveGroupingTraverseAction{
		CurrentBlockExpand: b.currentBlockExpand,
		SelectFieldIndex:   fieldIndex,
	}
	return expr.Traverse(traverseAction)
}

Expand Down Expand Up @@ -1812,7 +1867,7 @@ func (b *PlanBuilder) buildProjection(ctx context.Context, p base.LogicalPlan, f
// for case: select a+1, b, sum(b), grouping(a) from t group by a, b with rollup.
// the column inside aggregate (only sum(b) here) should be resolved to original source column,
// while for others, just use expanded columns if exists: a'+ 1, b', group(gid)
newExpr = b.replaceGroupingFunc(newExpr)
newExpr = b.replaceGroupingFuncByFieldIndex(newExpr, i)

// For window functions in the order by clause, we will append an field for it.
// We need rewrite the window mapper here so order by clause could find the added field.
Expand Down Expand Up @@ -3371,7 +3426,8 @@ type gbyResolver struct {
isParam bool
skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn

exprDepth int // exprDepth is the depth of current expression in expression tree.
exprDepth int // exprDepth is the depth of current expression in expression tree.
sourceFieldIndex int
}

func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) {
Expand Down Expand Up @@ -3422,6 +3478,9 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
if isParam, ok := ret.(*driver.ParamMarkerExpr); ok {
isParam.UseAsValueInGbyByClause = true
}
if !g.inExpr {
g.sourceFieldIndex = index
}
return ret, true
}
}
Expand Down Expand Up @@ -3450,6 +3509,7 @@ func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) {
g.err = plannererrors.ErrWrongGroupField.GenWithStackByArgs(fieldName)
return inNode, false
}
g.sourceFieldIndex = pos - 1
return ret, true
case *ast.ValuesExpr:
if v.Column == nil {
Expand Down Expand Up @@ -4025,9 +4085,10 @@ func allColFromExprNode(p base.LogicalPlan, n ast.Node, names map[*types.FieldNa
// The returned `[]ast.Node` may differ from the original `gby.Items` in the group by clause for params. For params, the
// `gby.Items[].Expr` will not be overwritten. However, the resolved expression is still needed for further processing, so
// it's returned out.
func (b *PlanBuilder) resolveGbyExprs(p base.LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) ([]ast.ExprNode, error) {
func (b *PlanBuilder) resolveGbyExprs(p base.LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) ([]ast.ExprNode, []int, error) {
b.curClause = groupByClause
exprs := make([]ast.ExprNode, 0, len(gby.Items))
sourceFieldIndices := make([]int, 0, len(gby.Items))
schema := p.Schema()
names := p.OutputNames()
// findJoinFullSchema walks through transparent wrappers (LogicalSelection)
Expand All @@ -4048,17 +4109,19 @@ func (b *PlanBuilder) resolveGbyExprs(p base.LogicalPlan, gby *ast.GroupByClause
resolver.inExpr = false
resolver.exprDepth = 0
resolver.isParam = false
resolver.sourceFieldIndex = -1
retExpr, _ := item.Expr.Accept(resolver)
if resolver.err != nil {
return exprs, errors.Trace(resolver.err)
return exprs, sourceFieldIndices, errors.Trace(resolver.err)
}
if !resolver.isParam {
item.Expr = retExpr.(ast.ExprNode)
}

exprs = append(exprs, retExpr.(ast.ExprNode))
sourceFieldIndices = append(sourceFieldIndices, resolver.sourceFieldIndex)
}
return exprs, nil
return exprs, sourceFieldIndices, nil
}

func (b *PlanBuilder) rewriteGbyExprs(ctx context.Context, p base.LogicalPlan, gby *ast.GroupByClause, items []ast.ExprNode) (base.LogicalPlan, []expression.Expression, bool, error) {
Expand Down Expand Up @@ -4357,8 +4420,9 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p b
}

var gbyExprs []ast.ExprNode
var gbyItemSourceFieldIndices []int
if sel.GroupBy != nil {
gbyExprs, err = b.resolveGbyExprs(p, sel.GroupBy, sel.Fields.Fields)
gbyExprs, gbyItemSourceFieldIndices, err = b.resolveGbyExprs(p, sel.GroupBy, sel.Fields.Fields)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -4502,7 +4566,7 @@ func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p b
if needBuildAgg {
// if rollup syntax is specified, Expand OP is required to replicate the data to feed different grouping layout.
if rollup {
p, gbyCols, err = b.buildExpand(p, gbyCols)
p, gbyCols, err = b.buildExpand(p, gbyCols, gbyItemSourceFieldIndices)
if err != nil {
return nil, err
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/planner/core/operator/logicalop/logical_expand.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ type LogicalExpand struct {
// keep the old gbyExprs for resolve cases like grouping(a+b), the args:
// a+b should be resolved to new projected gby col according to ref pos.
DistinctGbyExprs []expression.Expression `hash64-equals:"true"`
// GbyItemSourceFieldIndices records which SELECT field produced each GROUP BY item.
// -1 means the GROUP BY item was not resolved from a SELECT field alias or ordinal.
GbyItemSourceFieldIndices []int

// rollup grouping sets.
DistinctSize int `hash64-equals:"true"`
Expand Down Expand Up @@ -300,6 +303,30 @@ func (p *LogicalExpand) TrySubstituteExprWithGroupingSetCol(expr expression.Expr
return expr, false
}

// TrySubstituteExprWithGroupingSetColByFieldIndex is like TrySubstituteExprWithGroupingSetCol,
// but it first preserves SELECT-field to GROUP-BY-item bindings produced by aliases or ordinals.
//
// For a non-negative fieldIndex the lookup order is:
//  1. the first GROUP BY item whose recorded source field index equals fieldIndex and
//     whose canonical hash code equals that of expr;
//  2. otherwise the first item recorded with source field index -1 and an equal hash code;
//  3. otherwise whatever TrySubstituteExprWithGroupingSetCol returns.
//
// A negative fieldIndex goes straight to TrySubstituteExprWithGroupingSetCol.
// The returned bool reports whether a substitution happened.
func (p *LogicalExpand) TrySubstituteExprWithGroupingSetColByFieldIndex(expr expression.Expression, fieldIndex int) (expression.Expression, bool) {
	if fieldIndex < 0 || len(p.GbyItemSourceFieldIndices) == 0 {
		return p.TrySubstituteExprWithGroupingSetCol(expr)
	}
	// expr's hash is the same for every comparison, so compute it once.
	hashCode := expr.CanonicalHashCode()
	// unbound remembers the first matching item recorded with -1; it is only used
	// when no item bound to fieldIndex matches.
	unbound := -1
	for i, sourceFieldIndex := range p.GbyItemSourceFieldIndices {
		// The three slices are read in parallel; stop rather than index past the shortest.
		if i >= len(p.DistinctGbyExprs) || i >= len(p.DistinctGroupByCol) {
			break
		}
		bound := sourceFieldIndex == fieldIndex
		if !bound && (sourceFieldIndex != -1 || unbound >= 0) {
			continue
		}
		if !bytes.Equal(hashCode, p.DistinctGbyExprs[i].CanonicalHashCode()) {
			continue
		}
		if bound {
			return p.DistinctGroupByCol[i], true
		}
		unbound = i
	}
	if unbound >= 0 {
		return p.DistinctGroupByCol[unbound], true
	}
	return p.TrySubstituteExprWithGroupingSetCol(expr)
}

// ResolveGroupingFuncArgsInGroupBy checks whether grouping function args is in grouping items.
func (p *LogicalExpand) ResolveGroupingFuncArgsInGroupBy(groupingFuncArgs []expression.Expression) ([]*expression.Column, error) {
// build GBYColMap
Expand Down
Loading