Skip to content

Commit 85c6b10

Browse files
committed
Add support for struct's methods, improve ~ operator types, fix ?: operator return type
1 parent 1dc7443 commit 85c6b10

File tree

6 files changed

+171
-8
lines changed

6 files changed

+171
-8
lines changed

eval.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ func (n methodNode) Eval(env interface{}) (interface{}, error) {
268268
return nil, err
269269
}
270270

271-
method, ok := extract(v, n.method)
271+
method, ok := getFunc(v, n.method)
272272
if !ok {
273273
return nil, fmt.Errorf("cannot get method %v from %T: %v", n.method, v, n)
274274
}
@@ -324,7 +324,7 @@ func (n builtinNode) Eval(env interface{}) (interface{}, error) {
324324
}
325325

326326
func (n functionNode) Eval(env interface{}) (interface{}, error) {
327-
fn, ok := extract(env, n.name)
327+
fn, ok := getFunc(env, n.name)
328328
if !ok {
329329
return nil, fmt.Errorf("undefined: %v", n.name)
330330
}

eval_test.go

+57-1
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ func TestEval_complex(t *testing.T) {
564564

565565
expected := true
566566
if !reflect.DeepEqual(actual, expected) {
567-
t.Fatalf("TestEvalComplex:\ngot\n\t%#v\nexpected:\n\t%#v", actual, expected)
567+
t.Fatalf("TestEval_complex:\ngot\n\t%#v\nexpected:\n\t%#v", actual, expected)
568568
}
569569
}
570570

@@ -584,3 +584,59 @@ func TestEval_panic(t *testing.T) {
584584
t.Errorf("\ngot\n\t%+v\nexpected\n\t%v", err.Error(), expected)
585585
}
586586
}
587+
588+
func TestEval_method(t *testing.T) {
589+
env := testEnv{
590+
Hello: "hello",
591+
World: testWorld{
592+
name: []string{"w", "o", "r", "l", "d"},
593+
},
594+
testVersion: testVersion{
595+
version: 2,
596+
},
597+
}
598+
599+
input := `Title(Hello) ~ ' ' ~ (CompareVersion(1, 3) ? World.String() : '')`
600+
601+
node, err := expr.Parse(input)
602+
fmt.Printf("%#v\n", node)
603+
if err != nil {
604+
t.Fatal(err)
605+
}
606+
607+
actual, err := expr.Run(node, env)
608+
if err != nil {
609+
t.Fatal(err)
610+
}
611+
612+
expected := "Hello world"
613+
if !reflect.DeepEqual(actual, expected) {
614+
t.Fatalf("TestEval_method:\ngot\n\t%#v\nexpected:\n\t%#v", actual, expected)
615+
}
616+
}
617+
618+
type testVersion struct {
619+
version float64
620+
}
621+
622+
func (c testVersion) CompareVersion(min, max float64) bool {
623+
return min < c.version && c.version < max
624+
}
625+
626+
type testWorld struct {
627+
name []string
628+
}
629+
630+
func (w testWorld) String() string {
631+
return strings.Join(w.name, "")
632+
}
633+
634+
type testEnv struct {
635+
testVersion
636+
Hello string
637+
World testWorld
638+
}
639+
640+
func (e testEnv) Title(s string) string {
641+
return strings.Title(s)
642+
}

parser.go

+5
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,11 @@ func (p *parser) fromStruct(t reflect.Type) typesTable {
181181

182182
types[f.Name] = f.Type
183183
}
184+
185+
for i := 0; i < t.NumMethod(); i++ {
186+
m := t.Method(i)
187+
types[m.Name] = m.Type
188+
}
184189
}
185190

186191
return types

runtime.go

+27
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,33 @@ func extract(val interface{}, i interface{}) (interface{}, bool) {
101101
return nil, false
102102
}
103103

104+
func getFunc(val interface{}, i interface{}) (interface{}, bool) {
105+
v := reflect.ValueOf(val)
106+
switch v.Kind() {
107+
case reflect.Map:
108+
value := v.MapIndex(reflect.ValueOf(i))
109+
if value.IsValid() && value.CanInterface() {
110+
return value.Interface(), true
111+
}
112+
case reflect.Struct:
113+
name := reflect.ValueOf(i).String()
114+
method := v.MethodByName(name)
115+
if method.IsValid() && method.CanInterface() {
116+
return method.Interface(), true
117+
}
118+
value := v.FieldByName(name)
119+
if value.IsValid() && value.CanInterface() {
120+
return value.Interface(), true
121+
}
122+
case reflect.Ptr:
123+
value := v.Elem()
124+
if value.IsValid() && value.CanInterface() {
125+
return getFunc(value.Interface(), i)
126+
}
127+
}
128+
return nil, false
129+
}
130+
104131
func contains(needle interface{}, array interface{}) (bool, error) {
105132
if array != nil {
106133
v := reflect.ValueOf(array)

type.go

+64-4
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ func (n binaryNode) Type(table typesTable) (Type, error) {
115115
}
116116
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
117117

118+
case "~":
119+
if (isStringType(ltype) || isInterfaceType(ltype)) && (isStringType(rtype) || isInterfaceType(rtype)) {
120+
return textType, nil
121+
}
122+
return nil, fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, n, ltype, rtype)
123+
118124
}
119125

120126
return interfaceType, nil
@@ -173,7 +179,7 @@ func (n methodNode) Type(table typesTable) (Type, error) {
173179
return nil, err
174180
}
175181
}
176-
if t, ok := fieldType(ntype, n.method); ok {
182+
if t, ok := methodType(ntype, n.method); ok {
177183
if f, ok := funcType(t); ok {
178184
return f, nil
179185
}
@@ -220,15 +226,29 @@ func (n conditionalNode) Type(table typesTable) (Type, error) {
220226
if !isBoolType(ctype) && !isInterfaceType(ctype) {
221227
return nil, fmt.Errorf("non-bool %v (type %v) used as condition", n.cond, ctype)
222228
}
223-
_, err = n.exp1.Type(table)
229+
230+
t1, err := n.exp1.Type(table)
224231
if err != nil {
225232
return nil, err
226233
}
227-
_, err = n.exp2.Type(table)
234+
t2, err := n.exp2.Type(table)
228235
if err != nil {
229236
return nil, err
230237
}
231-
return boolType, nil
238+
239+
if t1 == nil && t2 != nil {
240+
return t2, nil
241+
}
242+
if t1 != nil && t2 == nil {
243+
return t1, nil
244+
}
245+
if t1 == nil && t2 == nil {
246+
return nilType, nil
247+
}
248+
if t1.AssignableTo(t2) {
249+
return t1, nil
250+
}
251+
return interfaceType, nil
232252
}
233253

234254
func (n arrayNode) Type(table typesTable) (Type, error) {
@@ -399,6 +419,46 @@ func fieldType(ntype Type, name string) (Type, bool) {
399419
return nil, false
400420
}
401421

422+
func methodType(ntype Type, name string) (Type, bool) {
423+
ntype = dereference(ntype)
424+
if ntype != nil {
425+
switch ntype.Kind() {
426+
case reflect.Interface:
427+
return interfaceType, true
428+
case reflect.Struct:
429+
// First check all struct's methods.
430+
for i := 0; i < ntype.NumMethod(); i++ {
431+
m := ntype.Method(i)
432+
if m.Name == name {
433+
return m.Type, true
434+
}
435+
}
436+
437+
// Second check all struct's fields.
438+
for i := 0; i < ntype.NumField(); i++ {
439+
f := ntype.Field(i)
440+
if !f.Anonymous && f.Name == name {
441+
return f.Type, true
442+
}
443+
}
444+
445+
// Third check fields of embedded structs.
446+
for i := 0; i < ntype.NumField(); i++ {
447+
f := ntype.Field(i)
448+
if f.Anonymous {
449+
if t, ok := methodType(f.Type, name); ok {
450+
return t, true
451+
}
452+
}
453+
}
454+
case reflect.Map:
455+
return ntype.Elem(), true
456+
}
457+
}
458+
459+
return nil, false
460+
}
461+
402462
func indexType(ntype Type) (Type, bool) {
403463
ntype = dereference(ntype)
404464
if ntype == nil {

type_test.go

+16-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ var typeTests = []typeTest{
2525
"Fn(Any)",
2626
"Foo.Fn()",
2727
"true ? Any : Any",
28+
"Str ~ (true ? Str : Str)",
2829
"Ok && Any",
2930
"Str matches 'ok'",
3031
"Str matches Any",
@@ -76,11 +77,13 @@ var typeTests = []typeTest{
7677
"EmbStr == ''",
7778
"Embedded.EmbStr",
7879
"EmbPtrStr == ''",
79-
"EmbeddedPtr ~ Str",
80+
"EmbeddedPtr.EmbPtrStr ~ Str",
8081
"SubStr ~ ''",
8182
"SubEmbedded.SubStr",
8283
"OkFn() and OkFn()",
8384
"Foo.Fn() or Foo.Fn()",
85+
"Method() > 1",
86+
"Embedded.Method() ~ Str",
8487
}
8588

8689
var typeErrorTests = []typeErrorTest{
@@ -284,6 +287,10 @@ var typeErrorTests = []typeErrorTest{
284287
"1 in Foo",
285288
"invalid operation: 1 in Foo (mismatched types float64 and *expr_test.foo)",
286289
},
290+
{
291+
"1 ~ ''",
292+
`invalid operation: 1 ~ "" (mismatched types float64 and string)`,
293+
},
287294
}
288295

289296
type abc interface {
@@ -331,6 +338,14 @@ type payload struct {
331338
NilFn func()
332339
}
333340

341+
func (p payload) Method() int {
342+
return 0
343+
}
344+
345+
func (p Embedded) Method() string {
346+
return ""
347+
}
348+
334349
func TestType(t *testing.T) {
335350
for _, test := range typeTests {
336351
_, err := expr.Parse(string(test), expr.Env(&payload{}))

0 commit comments

Comments
 (0)