Skip to content

Commit ad915e4

Browse files
committed
Add better support for methods and string comparison
1 parent 4ad9e52 commit ad915e4

File tree

5 files changed

+91
-39
lines changed

5 files changed

+91
-39
lines changed

checker/checker.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
155155
if (isFloat(l) && isIntegerNode(node.Right)) || (isIntegerNode(node.Left) && isFloat(r)) {
156156
return boolType
157157
}
158+
if isString(l) && isString(r) {
159+
return boolType
160+
}
158161

159162
case "or", "||", "and", "&&":
160163
if (isBool(l) || isInterface(l)) && (isBool(r) || isInterface(r)) {
@@ -179,6 +182,9 @@ func (v *visitor) BinaryNode(node *ast.BinaryNode) reflect.Type {
179182
if (isFloat(l) && isIntegerNode(node.Right)) || (isIntegerNode(node.Left) && isFloat(r)) {
180183
return boolType
181184
}
185+
if isString(l) && isString(r) {
186+
return boolType
187+
}
182188

183189
case "/", "-", "*", "**":
184190
if (isInteger(l) || isInterface(l)) && (isInteger(r) || isInterface(r)) {

checker/checker_test.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"regexp"
99
"strings"
1010
"testing"
11+
"time"
1112
)
1213

1314
func TestVisitor_FunctionNode(t *testing.T) {
@@ -31,7 +32,12 @@ func TestVisitor_MethodNode(t *testing.T) {
3132
var err error
3233

3334
env := &mockEnv{}
34-
input := `Var.Set(1, 0.5) + Var.Add(2) + Var.Any(true) + Var.Get() + Var.Sub(3)`
35+
input := `Var.Set(1, 0.5)
36+
+ Var.Add(2)
37+
+ Var.Any(true)
38+
+ Var.Get()
39+
+ Var.Sub(3)
40+
+ (Duration.String() == "" ? 1 : 0)`
3541

3642
tree, err := parser.Parse(input)
3743
assert.NoError(t, err)
@@ -65,10 +71,11 @@ func TestVisitor_BuiltinNode(t *testing.T) {
6571

6672
type mockEnv struct {
6773
*mockEmbed
68-
Add func(int64) int64
69-
Any interface{}
70-
Var *mockVar
71-
Tickets []mockTicket
74+
Add func(int64) int64
75+
Any interface{}
76+
Var *mockVar
77+
Tickets []mockTicket
78+
Duration time.Duration
7279
}
7380

7481
func (f *mockEnv) Set(v int64, any interface{}) int64 {
@@ -163,6 +170,7 @@ func TestCheck(t *testing.T) {
163170
"Foo.Fn() or Foo.Fn()",
164171
"Method(Foo.Bar) > 1",
165172
"Embedded.Method() + Str",
173+
`"a" < "b"`,
166174
}
167175
for _, test := range typeTests {
168176
var err error

checker/types.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,15 @@ func fieldType(ntype reflect.Type, name string) (reflect.Type, bool) {
182182

183183
func methodType(t reflect.Type, name string) (reflect.Type, bool, bool) {
184184
if t != nil {
185+
// First, check methods defined on type itself,
186+
// independent of which type it is.
187+
for i := 0; i < t.NumMethod(); i++ {
188+
m := t.Method(i)
189+
if m.Name == name {
190+
return m.Type, true, true
191+
}
192+
}
193+
185194
d := t
186195
if t.Kind() == reflect.Ptr {
187196
d = t.Elem()
@@ -191,23 +200,15 @@ func methodType(t reflect.Type, name string) (reflect.Type, bool, bool) {
191200
case reflect.Interface:
192201
return interfaceType, false, true
193202
case reflect.Struct:
194-
// First check all struct's methods.
195-
for i := 0; i < t.NumMethod(); i++ {
196-
m := t.Method(i)
197-
if m.Name == name {
198-
return m.Type, true, true
199-
}
200-
}
201-
202-
// Second check all struct's fields.
203+
// First, check all struct's fields.
203204
for i := 0; i < d.NumField(); i++ {
204205
f := d.Field(i)
205206
if !f.Anonymous && f.Name == name {
206207
return f.Type, false, true
207208
}
208209
}
209210

210-
// Third check fields of embedded structs.
211+
// Second, check fields of embedded structs.
211212
for i := 0; i < d.NumField(); i++ {
212213
f := d.Field(i)
213214
if f.Anonymous {

vm/runtime.go

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,15 @@ func fetch(from interface{}, i interface{}) interface{} {
4848

4949
func fetchFn(from interface{}, name string) reflect.Value {
5050
v := reflect.ValueOf(from)
51+
52+
// Methods can be defined on any type.
53+
if v.NumMethod() > 0 {
54+
method := v.MethodByName(name)
55+
if method.IsValid() {
56+
return method
57+
}
58+
}
59+
5160
d := v
5261
if v.Kind() == reflect.Ptr {
5362
d = v.Elem()
@@ -59,19 +68,7 @@ func fetchFn(from interface{}, name string) reflect.Value {
5968
if value.IsValid() && value.CanInterface() {
6069
return value.Elem()
6170
}
62-
// A map may have method too.
63-
if v.NumMethod() > 0 {
64-
method := v.MethodByName(name)
65-
if method.IsValid() {
66-
return method
67-
}
68-
}
6971
case reflect.Struct:
70-
method := v.MethodByName(name)
71-
if method.IsValid() {
72-
return method
73-
}
74-
7572
// If struct has not method, maybe it has func field.
7673
// To access this field we need dereference value.
7774
value := d.FieldByName(name)
@@ -222,6 +219,9 @@ func equal(a, b interface{}) bool {
222219
case uint64:
223220
return x == uint64(cast(b))
224221

222+
case string:
223+
return x == b.(string)
224+
225225
default:
226226
return reflect.DeepEqual(a, b)
227227
}
@@ -271,6 +271,9 @@ func less(a, b interface{}) interface{} {
271271
case uint64:
272272
return x < uint64(cast(b))
273273

274+
case string:
275+
return x < b.(string)
276+
274277
default:
275278
panic(fmt.Sprintf("invalid operation: %T < %T", a, b))
276279
}
@@ -320,6 +323,9 @@ func more(a, b interface{}) interface{} {
320323
case uint64:
321324
return x > uint64(cast(b))
322325

326+
case string:
327+
return x > b.(string)
328+
323329
default:
324330
panic(fmt.Sprintf("invalid operation: %T > %T", a, b))
325331
}
@@ -369,6 +375,9 @@ func lessOrEqual(a, b interface{}) interface{} {
369375
case uint64:
370376
return x <= uint64(cast(b))
371377

378+
case string:
379+
return x <= b.(string)
380+
372381
default:
373382
panic(fmt.Sprintf("invalid operation: %T <= %T", a, b))
374383
}
@@ -418,6 +427,9 @@ func moreOrEqual(a, b interface{}) interface{} {
418427
case uint64:
419428
return x >= uint64(cast(b))
420429

430+
case string:
431+
return x >= b.(string)
432+
421433
default:
422434
panic(fmt.Sprintf("invalid operation: %T >= %T", a, b))
423435
}

vm/vm_test.go

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@ import (
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
1111
"testing"
12+
"time"
1213
)
1314

1415
func TestRun_debug(t *testing.T) {
1516
var test = struct {
1617
input string
1718
output interface{}
1819
}{
19-
`filter([1,2,3], {# > 2})`,
20-
[]interface{}{int64(3)},
20+
`Now.Sub(BirthDay).String() != Duration("1h").String()`,
21+
true,
2122
}
2223

2324
env := &mockEnv{}
@@ -211,6 +212,18 @@ func TestRun(t *testing.T) {
211212
`one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`,
212213
true,
213214
},
215+
{
216+
`Now.After(BirthDay)`,
217+
true,
218+
},
219+
{
220+
`"a" < "b"`,
221+
true,
222+
},
223+
{
224+
`Now.Sub(Now).String() == Duration("0s").String()`,
225+
true,
226+
},
214227
}
215228

216229
env := &mockEnv{
@@ -226,6 +239,8 @@ func TestRun(t *testing.T) {
226239
Ticket: &mockTicket{
227240
Price: 100,
228241
},
242+
BirthDay: time.Date(2017, time.October, 23, 18, 30, 0, 0, time.UTC),
243+
Now: time.Now(),
229244
}
230245

231246
for _, test := range tests {
@@ -246,16 +261,18 @@ func TestRun(t *testing.T) {
246261
}
247262

248263
type mockEnv struct {
249-
Any interface{}
250-
Int int
251-
Int32 int32
252-
Int64 int64
253-
Uint64 uint64
254-
Float64 float64
255-
Bool bool
256-
String string
257-
Array []int
258-
Ticket *mockTicket
264+
Any interface{}
265+
Int int
266+
Int32 int32
267+
Int64 int64
268+
Uint64 uint64
269+
Float64 float64
270+
Bool bool
271+
String string
272+
Array []int
273+
Ticket *mockTicket
274+
BirthDay time.Time
275+
Now time.Time
259276
}
260277

261278
func (e *mockEnv) GetInt() int {
@@ -266,6 +283,14 @@ func (*mockEnv) Add(a, b int64) int {
266283
return int(a + b)
267284
}
268285

286+
func (*mockEnv) Duration(s string) time.Duration {
287+
d, err := time.ParseDuration(s)
288+
if err != nil {
289+
panic(err)
290+
}
291+
return d
292+
}
293+
269294
type mockTicket struct {
270295
Price int
271296
}

0 commit comments

Comments
 (0)