@@ -165,6 +165,8 @@ func (v *checker) visit(node ast.Node) (reflect.Type, info) {
165
165
t , i = v .MapNode (n )
166
166
case * ast.PairNode :
167
167
t , i = v .PairNode (n )
168
+ case * ast.CompareNode :
169
+ t , i = v .CompareNode (n )
168
170
default :
169
171
panic (fmt .Sprintf ("undefined node type (%T)" , node ))
170
172
}
@@ -272,17 +274,12 @@ func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) {
272
274
273
275
func (v * checker ) BinaryNode (node * ast.BinaryNode ) (reflect.Type , info ) {
274
276
l , _ := v .visit (node .Left )
275
- r , ri := v .visit (node .Right )
277
+ r , _ := v .visit (node .Right )
276
278
277
279
l = deref .Type (l )
278
280
r = deref .Type (r )
279
281
280
282
switch node .Operator {
281
- case "==" , "!=" :
282
- if isComparable (l , r ) {
283
- return boolType , info {}
284
- }
285
-
286
283
case "or" , "||" , "and" , "&&" :
287
284
if isBool (l ) && isBool (r ) {
288
285
return boolType , info {}
@@ -291,20 +288,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
291
288
return boolType , info {}
292
289
}
293
290
294
- case "<" , ">" , ">=" , "<=" :
295
- if isNumber (l ) && isNumber (r ) {
296
- return boolType , info {}
297
- }
298
- if isString (l ) && isString (r ) {
299
- return boolType , info {}
300
- }
301
- if isTime (l ) && isTime (r ) {
302
- return boolType , info {}
303
- }
304
- if or (l , r , isNumber , isString , isTime ) {
305
- return boolType , info {}
306
- }
307
-
308
291
case "-" :
309
292
if isNumber (l ) && isNumber (r ) {
310
293
return combined (l , r ), info {}
@@ -368,60 +351,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
368
351
return anyType , info {}
369
352
}
370
353
371
- case "in" :
372
- if (isString (l ) || isAny (l )) && isStruct (r ) {
373
- return boolType , info {}
374
- }
375
- if isMap (r ) {
376
- if l == nil { // It is possible to compare with nil.
377
- return boolType , info {}
378
- }
379
- if ! isAny (l ) && ! l .AssignableTo (r .Key ()) {
380
- return v .error (node , "cannot use %v as type %v in map key" , l , r .Key ())
381
- }
382
- return boolType , info {}
383
- }
384
- if isArray (r ) {
385
- if l == nil { // It is possible to compare with nil.
386
- return boolType , info {}
387
- }
388
- if ! isComparable (l , r .Elem ()) {
389
- return v .error (node , "cannot use %v as type %v in array" , l , r .Elem ())
390
- }
391
- if ! isComparable (l , ri .elem ) {
392
- return v .error (node , "cannot use %v as type %v in array" , l , ri .elem )
393
- }
394
- return boolType , info {}
395
- }
396
- if isAny (l ) && anyOf (r , isString , isArray , isMap ) {
397
- return boolType , info {}
398
- }
399
- if isAny (r ) {
400
- return boolType , info {}
401
- }
402
-
403
- case "matches" :
404
- if s , ok := node .Right .(* ast.StringNode ); ok {
405
- _ , err := regexp .Compile (s .Value )
406
- if err != nil {
407
- return v .error (node , err .Error ())
408
- }
409
- }
410
- if isString (l ) && isString (r ) {
411
- return boolType , info {}
412
- }
413
- if or (l , r , isString ) {
414
- return boolType , info {}
415
- }
416
-
417
- case "contains" , "startsWith" , "endsWith" :
418
- if isString (l ) && isString (r ) {
419
- return boolType , info {}
420
- }
421
- if or (l , r , isString ) {
422
- return boolType , info {}
423
- }
424
-
425
354
case ".." :
426
355
ret := reflect .SliceOf (integerType )
427
356
if isInteger (l ) && isInteger (r ) {
@@ -448,7 +377,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
448
377
449
378
default :
450
379
return v .error (node , "unknown operator (%v)" , node .Operator )
451
-
452
380
}
453
381
454
382
return v .error (node , `invalid operation: %v (mismatched types %v and %v)` , node .Operator , l , r )
@@ -1207,3 +1135,95 @@ func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) {
1207
1135
v .visit (node .Value )
1208
1136
return nilType , info {}
1209
1137
}
1138
+
1139
+ func (v * checker ) CompareNode (node * ast.CompareNode ) (reflect.Type , info ) {
1140
+ nodeLeft := node .Left
1141
+ opIdx := 0
1142
+ operatorOverride := false
1143
+ for i , comparator := range node .Comparators {
1144
+ op := node .Operators [opIdx ]
1145
+ if negate := op == "not" ; negate {
1146
+ opIdx ++
1147
+ op = node .Operators [opIdx ]
1148
+ }
1149
+ if op == "&&" {
1150
+ if ! operatorOverride {
1151
+ operatorOverride = true
1152
+ }
1153
+ } else if err := v .compareNode (op , nodeLeft , comparator , i ); err != nil {
1154
+ return v .error (comparator , err .Error ())
1155
+ }
1156
+ opIdx ++
1157
+ nodeLeft = comparator
1158
+ }
1159
+ if operatorOverride {
1160
+ return anyType , info {}
1161
+ }
1162
+ return boolType , info {}
1163
+ }
1164
+
1165
+ func (v * checker ) compareNode (op string , nodeLeft , nodeRight ast.Node , index int ) error {
1166
+ l , _ := v .visit (nodeLeft )
1167
+ r , ri := v .visit (nodeRight )
1168
+ l = deref .Type (l )
1169
+ r = deref .Type (r )
1170
+ switch op {
1171
+ case "==" , "!=" :
1172
+ if (isBool (r ) && index > 0 ) || isComparable (l , r ) {
1173
+ return nil
1174
+ }
1175
+ case "<" , ">" , ">=" , "<=" :
1176
+ if isNumber (l ) && isNumber (r ) ||
1177
+ isString (l ) && isString (r ) ||
1178
+ isTime (l ) && isTime (r ) ||
1179
+ or (l , r , isNumber , isString , isTime ) {
1180
+ return nil
1181
+ }
1182
+ case "in" :
1183
+ if (isString (l ) || isAny (l )) && isStruct (r ) {
1184
+ return nil
1185
+ }
1186
+ if isMap (r ) {
1187
+ if l == nil { // It is possible to compare with nil.
1188
+ return nil
1189
+ }
1190
+ if ! isAny (l ) && ! l .AssignableTo (r .Key ()) {
1191
+ return fmt .Errorf ("cannot use %v as type %v in map key" , l , r .Key ())
1192
+ }
1193
+ return nil
1194
+ }
1195
+ if isArray (r ) {
1196
+ if l == nil { // It is possible to compare with nil.
1197
+ return nil
1198
+ }
1199
+ if ! isComparable (l , r .Elem ()) {
1200
+ return fmt .Errorf ("cannot use %v as type %v in array" , l , r .Elem ())
1201
+ }
1202
+ if ! isComparable (l , ri .elem ) {
1203
+ return fmt .Errorf ("cannot use %v as type %v in array" , l , ri .elem )
1204
+ }
1205
+ return nil
1206
+ }
1207
+ if (isAny (l ) && anyOf (r , isString , isArray , isMap )) || isAny (r ) {
1208
+ return nil
1209
+ }
1210
+
1211
+ case "matches" :
1212
+ if s , ok := nodeRight .(* ast.StringNode ); ok {
1213
+ if _ , err := regexp .Compile (s .Value ); err != nil {
1214
+ return err
1215
+ }
1216
+ }
1217
+ if (isString (l ) && isString (r )) || or (l , r , isString ) {
1218
+ return nil
1219
+ }
1220
+ case "contains" , "startsWith" , "endsWith" :
1221
+ if isString (l ) && isString (r ) ||
1222
+ or (l , r , isString ) {
1223
+ return nil
1224
+ }
1225
+ default :
1226
+ return fmt .Errorf ("unknown operator (%v)" , op )
1227
+ }
1228
+ return fmt .Errorf (`invalid operation: %v (mismatched types %v and %v)` , op , l , r )
1229
+ }
0 commit comments