Skip to content

Commit 4439fd8

Browse files
committed
Fix working with struct's method
1 parent 0c9427d commit 4439fd8

File tree

3 files changed

+33
-26
lines changed

3 files changed

+33
-26
lines changed

eval_test.go

+11-7
Original file line numberDiff line numberDiff line change
@@ -586,19 +586,19 @@ func TestEval_panic(t *testing.T) {
586586
}
587587

588588
func TestEval_method(t *testing.T) {
589-
env := testEnv{
589+
env := &testEnv{
590590
Hello: "hello",
591591
World: testWorld{
592592
name: []string{"w", "o", "r", "l", "d"},
593593
},
594-
testVersion: testVersion{
594+
testVersion: &testVersion{
595595
version: 2,
596596
},
597597
}
598598

599-
input := `Title(Hello) ~ ' ' ~ (CompareVersion(1, 3) ? World.String() : '')`
599+
input := `Title(Hello) ~ Empty() ~ (CompareVersion(1, 3) ? World.String() : '')`
600600

601-
node, err := expr.Parse(input)
601+
node, err := expr.Parse(input, expr.Env(&testEnv{}))
602602
fmt.Printf("%#v\n", node)
603603
if err != nil {
604604
t.Fatal(err)
@@ -619,7 +619,7 @@ type testVersion struct {
619619
version float64
620620
}
621621

622-
func (c testVersion) CompareVersion(min, max float64) bool {
622+
func (c *testVersion) CompareVersion(min, max float64) bool {
623623
return min < c.version && c.version < max
624624
}
625625

@@ -632,11 +632,15 @@ func (w testWorld) String() string {
632632
}
633633

634634
type testEnv struct {
635-
testVersion
635+
*testVersion
636636
Hello string
637637
World testWorld
638638
}
639639

640-
func (e testEnv) Title(s string) string {
640+
func (e *testEnv) Title(s string) string {
641641
return strings.Title(s)
642642
}
643+
644+
func (e *testEnv) Empty() string {
645+
return " "
646+
}

parser.go

+15-12
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,22 @@ func (p *parser) createTypesTable(i interface{}) typesTable {
140140
v := reflect.ValueOf(i)
141141
t := reflect.TypeOf(i)
142142

143-
t = dereference(t)
144-
if t == nil {
145-
return types
143+
d := t
144+
if t.Kind() == reflect.Ptr {
145+
d = t.Elem()
146146
}
147147

148-
switch t.Kind() {
148+
switch d.Kind() {
149149
case reflect.Struct:
150-
types = p.fromStruct(t)
150+
types = p.fieldsFromStruct(d)
151+
152+
// Methods of struct should be gathered from original struct with pointer,
153+
// as methods maybe declared on pointer receiver. Also this method retrieves
154+
// all embedded structs methods as well, no need to recursion.
155+
for i := 0; i < t.NumMethod(); i++ {
156+
m := t.Method(i)
157+
types[m.Name] = m.Type
158+
}
151159

152160
case reflect.Map:
153161
for _, key := range v.MapKeys() {
@@ -161,7 +169,7 @@ func (p *parser) createTypesTable(i interface{}) typesTable {
161169
return types
162170
}
163171

164-
func (p *parser) fromStruct(t reflect.Type) typesTable {
172+
func (p *parser) fieldsFromStruct(t reflect.Type) typesTable {
165173
types := make(typesTable)
166174
t = dereference(t)
167175
if t == nil {
@@ -174,18 +182,13 @@ func (p *parser) fromStruct(t reflect.Type) typesTable {
174182
f := t.Field(i)
175183

176184
if f.Anonymous {
177-
for name, typ := range p.fromStruct(f.Type) {
185+
for name, typ := range p.fieldsFromStruct(f.Type) {
178186
types[name] = typ
179187
}
180188
}
181189

182190
types[f.Name] = f.Type
183191
}
184-
185-
for i := 0; i < t.NumMethod(); i++ {
186-
m := t.Method(i)
187-
types[m.Name] = m.Type
188-
}
189192
}
190193

191194
return types

runtime.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,14 @@ func extract(val interface{}, i interface{}) (interface{}, bool) {
103103

104104
func getFunc(val interface{}, i interface{}) (interface{}, bool) {
105105
v := reflect.ValueOf(val)
106-
switch v.Kind() {
106+
d := v
107+
if v.Kind() == reflect.Ptr {
108+
d = v.Elem()
109+
}
110+
111+
switch d.Kind() {
107112
case reflect.Map:
108-
value := v.MapIndex(reflect.ValueOf(i))
113+
value := d.MapIndex(reflect.ValueOf(i))
109114
if value.IsValid() && value.CanInterface() {
110115
return value.Interface(), true
111116
}
@@ -119,11 +124,6 @@ func getFunc(val interface{}, i interface{}) (interface{}, bool) {
119124
if value.IsValid() && value.CanInterface() {
120125
return value.Interface(), true
121126
}
122-
case reflect.Ptr:
123-
value := v.Elem()
124-
if value.IsValid() && value.CanInterface() {
125-
return getFunc(value.Interface(), i)
126-
}
127127
}
128128
return nil, false
129129
}

0 commit comments

Comments
 (0)