Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,9 @@ func (b *baseBuiltinFunc) cloneFrom(from *baseBuiltinFunc) {
for _, arg := range from.args {
b.args = append(b.args, arg.Clone())
}
b.tp = from.tp
if from.tp != nil {
b.tp = from.tp.Clone()
}
b.pbCode = from.pbCode
b.childrenVectorizedOnce = new(sync.Once)
if from.ctor != nil {
Expand Down
33 changes: 28 additions & 5 deletions pkg/expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ func FoldConstant(ctx BuildContext, expr Expression) Expression {
return e
}

// cloneFoldedBranchWithRetType copies branches selected by special folding
// before FoldConstant applies the parent expression metadata to the returned expression.
func cloneFoldedBranchWithRetType(expr Expression) Expression {
switch e := expr.(type) {
case *Column:
cloned := e.Clone().(*Column)
cloned.RetType = e.RetType.Clone()
return cloned
case *CorrelatedColumn:
cloned := e.Clone().(*CorrelatedColumn)
cloned.RetType = e.RetType.Clone()
return cloned
case *ScalarFunction:
cloned := e.Clone().(*ScalarFunction)
cloned.RetType = cloned.Function.getRetTp()
return cloned
default:
return expr
}
}

func isNullHandler(ctx BuildContext, expr *ScalarFunction) (Expression, bool) {
arg0 := expr.GetArgs()[0]
if constArg, isConst := arg0.(*Constant); isConst {
Expand Down Expand Up @@ -85,9 +106,11 @@ func ifFoldHandler(ctx BuildContext, expr *ScalarFunction) (Expression, bool) {
return expr, false
}
if !isNull0 && arg0 != 0 {
return foldConstant(ctx, args[1])
foldedExpr, isDeferred := foldConstant(ctx, args[1])
return cloneFoldedBranchWithRetType(foldedExpr), isDeferred
}
return foldConstant(ctx, args[2])
foldedExpr, isDeferred := foldConstant(ctx, args[2])
return cloneFoldedBranchWithRetType(foldedExpr), isDeferred
}
// if the condition is not const, which branch is unknown to run, so directly return.
return expr, false
Expand All @@ -109,7 +132,7 @@ func ifNullFoldHandler(ctx BuildContext, expr *ScalarFunction) (Expression, bool
expr.GetType(ctx.GetEvalCtx()).SetCharset(args[1].GetType(ctx.GetEvalCtx()).GetCharset())
expr.GetType(ctx.GetEvalCtx()).SetCollate(args[1].GetType(ctx.GetEvalCtx()).GetCollate())

return foldedExpr, isConstant
return cloneFoldedBranchWithRetType(foldedExpr), isConstant
}
return constArg, isDeferred
}
Expand Down Expand Up @@ -141,7 +164,7 @@ func caseWhenHandler(ctx BuildContext, expr *ScalarFunction) (Expression, bool)
foldedExpr.GetType(ctx.GetEvalCtx()).SetDecimal(expr.GetType(ctx.GetEvalCtx()).GetDecimal())
return foldedExpr, isDeferredConst
}
return foldedExpr, isDeferredConst
return cloneFoldedBranchWithRetType(foldedExpr), isDeferredConst
}
}
// If the number of arguments in casewhen is odd, and the previous conditions
Expand All @@ -154,7 +177,7 @@ func caseWhenHandler(ctx BuildContext, expr *ScalarFunction) (Expression, bool)
foldedExpr.GetType(ctx.GetEvalCtx()).SetDecimal(expr.GetType(ctx.GetEvalCtx()).GetDecimal())
return foldedExpr, isDeferredConst
}
return foldedExpr, isDeferredConst
return cloneFoldedBranchWithRetType(foldedExpr), isDeferredConst
}
return expr, isDeferredConst
}
Expand Down
37 changes: 37 additions & 0 deletions pkg/expression/constant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,43 @@ func TestConstantFolding(t *testing.T) {
newConds := FoldConstant(ctx, expr)
require.Equalf(t, tt.result, newConds.StringWithCtx(ctx.GetEvalCtx(), errors.RedactLogDisable), "different for expr %s", tt.condition)
}

ctx := mock.NewContext().GetExprCtx()
col := newColumnWithType(0, types.NewFieldTypeWithCollation(mysql.TypeVarchar, "utf8mb4_bin", 255))
caseRetType := types.NewFieldTypeWithCollation(mysql.TypeVarchar, "binary", 255)
caseExpr := newFunctionWithType(ctx, ast.Case, caseRetType,
newLonglong(1),
col,
newString("", "binary"),
)
caseExpr.GetType(ctx.GetEvalCtx()).SetCharset("binary")
caseExpr.GetType(ctx.GetEvalCtx()).SetCollate("binary")
folded := FoldConstant(ctx, caseExpr)
require.IsType(t, &Column{}, folded)
require.NotSame(t, col, folded)
require.Equal(t, "utf8mb4", col.RetType.GetCharset())
require.Equal(t, "utf8mb4_bin", col.RetType.GetCollate())
require.Equal(t, "binary", folded.GetType(ctx.GetEvalCtx()).GetCharset())
require.Equal(t, "binary", folded.GetType(ctx.GetEvalCtx()).GetCollate())

lowerExpr := newFunctionWithType(ctx, ast.Lower, types.NewFieldType(mysql.TypeVarchar), col).(*ScalarFunction)
clonedLowerExpr := lowerExpr.Clone().(*ScalarFunction)
require.NotSame(t, lowerExpr.RetType, clonedLowerExpr.RetType)
clonedLowerExpr.GetType(ctx.GetEvalCtx()).SetCharset("binary")
require.Equal(t, "utf8mb4", lowerExpr.GetType(ctx.GetEvalCtx()).GetCharset())
require.Equal(t, "utf8mb4", clonedLowerExpr.Function.getRetTp().GetCharset())

ifRetType := types.NewFieldTypeWithCollation(mysql.TypeVarchar, "binary", 255)
ifExpr := newFunctionWithType(ctx, ast.If, ifRetType,
newLonglong(1),
lowerExpr,
newString("", "binary"),
)
foldedLowerExpr := FoldConstant(ctx, ifExpr).(*ScalarFunction)
require.NotSame(t, lowerExpr, foldedLowerExpr)
require.Same(t, foldedLowerExpr.RetType, foldedLowerExpr.Function.getRetTp())
require.Equal(t, "utf8mb4", lowerExpr.GetType(ctx.GetEvalCtx()).GetCharset())
require.Equal(t, "binary", foldedLowerExpr.Function.getRetTp().GetCharset())
}

func TestConstantFoldingCharsetConvert(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ func ScalarFuncs2Exprs(funcs []*ScalarFunction) []Expression {

// Clone implements Expression interface.
func (sf *ScalarFunction) Clone() Expression {
function := sf.Function.Clone()
c := &ScalarFunction{
FuncName: sf.FuncName,
RetType: sf.RetType,
Function: sf.Function.Clone(),
RetType: sf.RetType.Clone(),
Function: function,
}
// An implicit assumption: ScalarFunc.RetType == ScalarFunc.builtinFunc.RetType
if sf.canonicalhashcode != nil {
c.canonicalhashcode = slices.Clone(sf.canonicalhashcode)
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2410,4 +2410,19 @@ from (
from t0
group by t0.c1, t0.c0, t0.c2
) as s where ref3`).Check(testkit.Rows())

tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(c0 varchar(4) character set utf8mb4 collate utf8mb4_bin)")
tk.MustExec("insert into t1 values ('0'), ('袦')")
tk.MustQuery(`select /* issue:68053 direct */ concat(t1.c0, '#', reverse(t1.c0)) from t1
where true or (case false when 'mU*' then t1.c0 when t1.c0 then (char(t1.c0) = 1) else binary(true) end)
order by 1`).Check(testkit.Rows("0#0", "袦#袦"))
tk.MustQuery(`select /* issue:68053 derived */ concat(ref0, '#', reverse(ref0)) from (
select t1.c0 as ref0,
(true or (case false when 'mU*' then t1.c0 when t1.c0 then (char(t1.c0) = 1) else binary(true) end)) as ref1
from t1
) as s where ref1 order by 1`).Check(testkit.Rows("0#0", "袦#袦"))
tk.MustQuery("select distinct charset(c0), charset(reverse(c0)) from t1").
Check(testkit.Rows("utf8mb4 utf8mb4"))
tk.MustQuery("select hex(reverse('袦'))").Check(testkit.Rows("E8A2A6"))
}
Original file line number Diff line number Diff line change
Expand Up @@ -344,4 +344,3 @@ select * from t force index(bx) where concat(a, a) = 'aaaaaaaaaa';
explain format = 'plan_tree' select * from t force index(bx) where concat(a, a) = 'aaaaaaaaaa';
drop table t;
set @@tidb_enable_unsafe_substitute=0;

Loading