Skip to content

Commit ec784bf

Browse files
emickleiantonmedv
authored andcommitted
fix folding of constant ints
1 parent 11be624 commit ec784bf

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed

expr_test.go

+20-9
Original file line numberDiff line numberDiff line change
@@ -963,21 +963,32 @@ func TestExpr_calls_with_nil(t *testing.T) {
963963
require.Equal(t, true, out)
964964
}
965965

966-
func TestExpr_call_floatarg_func_with_negative_int(t *testing.T) {
966+
func TestExpr_call_floatarg_func_with_int(t *testing.T) {
967967
env := map[string]interface{}{
968968
"cnv": func(f float64) interface{} {
969-
assert.Equal(t, -1, f)
970969
return f
971970
},
972971
}
973-
p, err := expr.Compile(
974-
"cnv(-1)",
975-
expr.Env(env))
976-
require.NoError(t, err)
972+
for _, each := range []struct {
973+
input string
974+
expected float64
975+
}{
976+
{"-1", -1.0},
977+
{"1+1", 2.0},
978+
{"+1", 1.0},
979+
{"1-1", 0.0},
980+
{"1/1", 1.0},
981+
{"1*1", 1.0},
982+
} {
983+
p, err := expr.Compile(
984+
fmt.Sprintf("cnv(%s)", each.input),
985+
expr.Env(env))
986+
require.NoError(t, err)
977987

978-
out, err := expr.Run(p, env)
979-
require.NoError(t, err)
980-
require.Equal(t, -1, out)
988+
out, err := expr.Run(p, env)
989+
require.NoError(t, err)
990+
require.Equal(t, each.expected, out)
991+
}
981992
}
982993

983994
func TestConstExpr_error(t *testing.T) {

optimizer/fold.go

+14-6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package optimizer
22

33
import (
44
"math"
5+
"reflect"
56

67
. "github.com/antonmedv/expr/ast"
78
)
@@ -16,17 +17,24 @@ func (fold *fold) Exit(node *Node) {
1617
fold.applied = true
1718
Patch(node, newNode)
1819
}
20+
// for IntegerNode the type may have been changed from int->float
21+
// preserve this information by setting the type after the Patch
22+
patchWithType := func(newNode Node, leafType reflect.Type) {
23+
fold.applied = true
24+
Patch(node, newNode)
25+
newNode.SetType(leafType)
26+
}
1927

2028
switch n := (*node).(type) {
2129
case *UnaryNode:
2230
switch n.Operator {
2331
case "-":
2432
if i, ok := n.Node.(*IntegerNode); ok {
25-
patch(&IntegerNode{Value: -i.Value})
33+
patchWithType(&IntegerNode{Value: -i.Value}, n.Node.Type())
2634
}
2735
case "+":
2836
if i, ok := n.Node.(*IntegerNode); ok {
29-
patch(&IntegerNode{Value: i.Value})
37+
patchWithType(&IntegerNode{Value: i.Value}, n.Node.Type())
3038
}
3139
}
3240

@@ -35,7 +43,7 @@ func (fold *fold) Exit(node *Node) {
3543
case "+":
3644
if a, ok := n.Left.(*IntegerNode); ok {
3745
if b, ok := n.Right.(*IntegerNode); ok {
38-
patch(&IntegerNode{Value: a.Value + b.Value})
46+
patchWithType(&IntegerNode{Value: a.Value + b.Value}, a.Type())
3947
}
4048
}
4149
if a, ok := n.Left.(*StringNode); ok {
@@ -46,19 +54,19 @@ func (fold *fold) Exit(node *Node) {
4654
case "-":
4755
if a, ok := n.Left.(*IntegerNode); ok {
4856
if b, ok := n.Right.(*IntegerNode); ok {
49-
patch(&IntegerNode{Value: a.Value - b.Value})
57+
patchWithType(&IntegerNode{Value: a.Value - b.Value}, a.Type())
5058
}
5159
}
5260
case "*":
5361
if a, ok := n.Left.(*IntegerNode); ok {
5462
if b, ok := n.Right.(*IntegerNode); ok {
55-
patch(&IntegerNode{Value: a.Value * b.Value})
63+
patchWithType(&IntegerNode{Value: a.Value * b.Value}, a.Type())
5664
}
5765
}
5866
case "/":
5967
if a, ok := n.Left.(*IntegerNode); ok {
6068
if b, ok := n.Right.(*IntegerNode); ok {
61-
patch(&IntegerNode{Value: a.Value / b.Value})
69+
patchWithType(&IntegerNode{Value: a.Value / b.Value}, a.Type())
6270
}
6371
}
6472
case "%":

0 commit comments

Comments
 (0)