Skip to content

Commit 596f54f

Browse files
authored
Invoke the Deref function as needed for the function arguments. (#651)
1 parent c6c7227 commit 596f54f

File tree

3 files changed

+109
-8
lines changed

3 files changed

+109
-8
lines changed

checker/checker.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ func (v *checker) checkArguments(
10441044
continue
10451045
}
10461046

1047-
if !t.AssignableTo(in) && kind(t) != reflect.Interface {
1047+
if !(t.AssignableTo(in) || deref.Type(t).AssignableTo(in)) && kind(t) != reflect.Interface {
10481048
return anyType, &file.Error{
10491049
Location: arg.Location(),
10501050
Message: fmt.Sprintf("cannot use %v as argument (type %v) to call %v ", t, in, name),

compiler/compiler.go

+41-7
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,8 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) {
592592
}
593593

594594
func (c *compiler) equalBinaryNode(node *ast.BinaryNode) {
595-
l := kind(node.Left)
596-
r := kind(node.Right)
595+
l := kind(node.Left.Type())
596+
r := kind(node.Right.Type())
597597

598598
leftIsSimple := isSimpleType(node.Left)
599599
rightIsSimple := isSimpleType(node.Right)
@@ -727,9 +727,44 @@ func (c *compiler) SliceNode(node *ast.SliceNode) {
727727
}
728728

729729
func (c *compiler) CallNode(node *ast.CallNode) {
730-
for _, arg := range node.Arguments {
731-
c.compile(arg)
730+
fn := node.Callee.Type()
731+
if kind(fn) == reflect.Func {
732+
fnInOffset := 0
733+
fnNumIn := fn.NumIn()
734+
switch callee := node.Callee.(type) {
735+
case *ast.MemberNode:
736+
if prop, ok := callee.Property.(*ast.StringNode); ok {
737+
if _, ok = callee.Node.Type().MethodByName(prop.Value); ok && callee.Node.Type().Kind() != reflect.Interface {
738+
fnInOffset = 1
739+
fnNumIn--
740+
}
741+
}
742+
case *ast.IdentifierNode:
743+
if t, ok := c.config.Types[callee.Value]; ok && t.Method {
744+
fnInOffset = 1
745+
fnNumIn--
746+
}
747+
}
748+
for i, arg := range node.Arguments {
749+
c.compile(arg)
750+
if k := kind(arg.Type()); k == reflect.Ptr || k == reflect.Interface {
751+
var in reflect.Type
752+
if fn.IsVariadic() && i >= fnNumIn-1 {
753+
in = fn.In(fn.NumIn() - 1).Elem()
754+
} else {
755+
in = fn.In(i + fnInOffset)
756+
}
757+
if k = kind(in); k != reflect.Ptr && k != reflect.Interface {
758+
c.emit(OpDeref)
759+
}
760+
}
761+
}
762+
} else {
763+
for _, arg := range node.Arguments {
764+
c.compile(arg)
765+
}
732766
}
767+
733768
if ident, ok := node.Callee.(*ast.IdentifierNode); ok {
734769
if c.config != nil {
735770
if fn, ok := c.config.Functions[ident.Value]; ok {
@@ -1162,7 +1197,7 @@ func (c *compiler) PairNode(node *ast.PairNode) {
11621197
}
11631198

11641199
func (c *compiler) derefInNeeded(node ast.Node) {
1165-
switch kind(node) {
1200+
switch kind(node.Type()) {
11661201
case reflect.Ptr, reflect.Interface:
11671202
c.emit(OpDeref)
11681203
}
@@ -1181,8 +1216,7 @@ func (c *compiler) optimize() {
11811216
}
11821217
}
11831218

1184-
func kind(node ast.Node) reflect.Kind {
1185-
t := node.Type()
1219+
func kind(t reflect.Type) reflect.Kind {
11861220
if t == nil {
11871221
return reflect.Invalid
11881222
}

test/deref/deref_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package deref_test
33
import (
44
"context"
55
"testing"
6+
"time"
67

78
"github.com/expr-lang/expr/internal/testify/assert"
89
"github.com/expr-lang/expr/internal/testify/require"
@@ -253,3 +254,69 @@ func TestDeref_fetch_from_interface_mix_pointer(t *testing.T) {
253254
assert.NoError(t, err)
254255
assert.Equal(t, "waldo", res)
255256
}
257+
258+
func TestDeref_func_args(t *testing.T) {
259+
i := 20
260+
env := map[string]any{
261+
"var": &i,
262+
"fn": func(p int) int {
263+
return p + 1
264+
},
265+
}
266+
267+
program, err := expr.Compile(`fn(var) + fn(var + 0)`, expr.Env(env))
268+
require.NoError(t, err)
269+
270+
out, err := expr.Run(program, env)
271+
require.NoError(t, err)
272+
require.Equal(t, 42, out)
273+
}
274+
275+
func TestDeref_struct_func_args(t *testing.T) {
276+
n, _ := time.Parse(time.RFC3339, "2024-05-12T18:30:00+00:00")
277+
duration := 30 * time.Minute
278+
env := map[string]any{
279+
"time": n,
280+
"duration": &duration,
281+
}
282+
283+
program, err := expr.Compile(`time.Add(duration).Format('2006-01-02T15:04:05Z07:00')`, expr.Env(env))
284+
require.NoError(t, err)
285+
286+
out, err := expr.Run(program, env)
287+
require.NoError(t, err)
288+
require.Equal(t, "2024-05-12T19:00:00Z", out)
289+
}
290+
291+
func TestDeref_ignore_func_args(t *testing.T) {
292+
f := foo(1)
293+
env := map[string]any{
294+
"foo": &f,
295+
"fn": func(f *foo) int {
296+
return f.Bar()
297+
},
298+
}
299+
300+
program, err := expr.Compile(`fn(foo)`, expr.Env(env))
301+
require.NoError(t, err)
302+
303+
out, err := expr.Run(program, env)
304+
require.NoError(t, err)
305+
require.Equal(t, 42, out)
306+
}
307+
308+
func TestDeref_ignore_struct_func_args(t *testing.T) {
309+
n := time.Now()
310+
location, _ := time.LoadLocation("UTC")
311+
env := map[string]any{
312+
"time": n,
313+
"location": location,
314+
}
315+
316+
program, err := expr.Compile(`time.In(location).Location().String()`, expr.Env(env))
317+
require.NoError(t, err)
318+
319+
out, err := expr.Run(program, env)
320+
require.NoError(t, err)
321+
require.Equal(t, "UTC", out)
322+
}

0 commit comments

Comments
 (0)