Skip to content

Commit 6bbe56c

Browse files
committed
Enable comparison feature like python
1 parent 7e6e6f5 commit 6bbe56c

File tree

13 files changed

+665
-494
lines changed

13 files changed

+665
-494
lines changed

ast/node.go

+8
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,11 @@ type PairNode struct {
214214
Key Node // Key of the pair.
215215
Value Node // Value of the pair.
216216
}
217+
218+
// CompareNode represents comparison
219+
type CompareNode struct {
220+
base
221+
Left Node // Left represents the left-hand side of the comparison operation
222+
Operators []string // Operators is a list of comparison operator tokens used in the comparison.
223+
Comparators []Node // Comparators representing the right-hand sides of the comparison operation
224+
}

ast/print.go

+31
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,34 @@ func (n *PairNode) String() string {
219219
}
220220
return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String())
221221
}
222+
223+
func (n *CompareNode) string(node Node) string {
224+
switch v := node.(type) {
225+
case *BinaryNode, *CompareNode:
226+
return fmt.Sprintf("(%s)", v)
227+
default:
228+
return v.String()
229+
}
230+
}
231+
232+
func (n *CompareNode) String() string {
233+
var builder strings.Builder
234+
builder.WriteString(n.string(n.Left))
235+
opIdx := 0
236+
for i := 0; i < len(n.Comparators); i++ {
237+
if op := n.Operators[opIdx]; op != "&&" {
238+
builder.WriteByte(' ')
239+
builder.WriteString(op)
240+
if op == "not" {
241+
opIdx++
242+
builder.WriteByte(' ')
243+
builder.WriteString(n.Operators[opIdx])
244+
}
245+
builder.WriteByte(' ')
246+
builder.WriteString(n.string(n.Comparators[i]))
247+
}
248+
opIdx++
249+
}
250+
251+
return builder.String()
252+
}

ast/print_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func TestPrint(t *testing.T) {
4141
{`a == b`, `a == b`},
4242
{`a matches b`, `a matches b`},
4343
{`a in b`, `a in b`},
44-
{`a not in b`, `not (a in b)`},
44+
{`a not in b`, `a not in b`},
4545
{`a and b`, `a and b`},
4646
{`a or b`, `a or b`},
4747
{`a or b and c`, `a or (b and c)`},

ast/visitor.go

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ func Walk(node *Node, v Visitor) {
6666
case *PairNode:
6767
Walk(&n.Key, v)
6868
Walk(&n.Value, v)
69+
case *CompareNode:
70+
Walk(&n.Left, v)
71+
for i := range n.Comparators {
72+
Walk(&n.Comparators[i], v)
73+
}
6974
default:
7075
panic(fmt.Sprintf("undefined node type (%T)", node))
7176
}

checker/checker.go

+95-75
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ func (v *checker) visit(node ast.Node) (reflect.Type, info) {
165165
t, i = v.MapNode(n)
166166
case *ast.PairNode:
167167
t, i = v.PairNode(n)
168+
case *ast.CompareNode:
169+
t, i = v.CompareNode(n)
168170
default:
169171
panic(fmt.Sprintf("undefined node type (%T)", node))
170172
}
@@ -272,17 +274,12 @@ func (v *checker) UnaryNode(node *ast.UnaryNode) (reflect.Type, info) {
272274

273275
func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
274276
l, _ := v.visit(node.Left)
275-
r, ri := v.visit(node.Right)
277+
r, _ := v.visit(node.Right)
276278

277279
l = deref.Type(l)
278280
r = deref.Type(r)
279281

280282
switch node.Operator {
281-
case "==", "!=":
282-
if isComparable(l, r) {
283-
return boolType, info{}
284-
}
285-
286283
case "or", "||", "and", "&&":
287284
if isBool(l) && isBool(r) {
288285
return boolType, info{}
@@ -291,20 +288,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
291288
return boolType, info{}
292289
}
293290

294-
case "<", ">", ">=", "<=":
295-
if isNumber(l) && isNumber(r) {
296-
return boolType, info{}
297-
}
298-
if isString(l) && isString(r) {
299-
return boolType, info{}
300-
}
301-
if isTime(l) && isTime(r) {
302-
return boolType, info{}
303-
}
304-
if or(l, r, isNumber, isString, isTime) {
305-
return boolType, info{}
306-
}
307-
308291
case "-":
309292
if isNumber(l) && isNumber(r) {
310293
return combined(l, r), info{}
@@ -368,60 +351,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
368351
return anyType, info{}
369352
}
370353

371-
case "in":
372-
if (isString(l) || isAny(l)) && isStruct(r) {
373-
return boolType, info{}
374-
}
375-
if isMap(r) {
376-
if l == nil { // It is possible to compare with nil.
377-
return boolType, info{}
378-
}
379-
if !isAny(l) && !l.AssignableTo(r.Key()) {
380-
return v.error(node, "cannot use %v as type %v in map key", l, r.Key())
381-
}
382-
return boolType, info{}
383-
}
384-
if isArray(r) {
385-
if l == nil { // It is possible to compare with nil.
386-
return boolType, info{}
387-
}
388-
if !isComparable(l, r.Elem()) {
389-
return v.error(node, "cannot use %v as type %v in array", l, r.Elem())
390-
}
391-
if !isComparable(l, ri.elem) {
392-
return v.error(node, "cannot use %v as type %v in array", l, ri.elem)
393-
}
394-
return boolType, info{}
395-
}
396-
if isAny(l) && anyOf(r, isString, isArray, isMap) {
397-
return boolType, info{}
398-
}
399-
if isAny(r) {
400-
return boolType, info{}
401-
}
402-
403-
case "matches":
404-
if s, ok := node.Right.(*ast.StringNode); ok {
405-
_, err := regexp.Compile(s.Value)
406-
if err != nil {
407-
return v.error(node, err.Error())
408-
}
409-
}
410-
if isString(l) && isString(r) {
411-
return boolType, info{}
412-
}
413-
if or(l, r, isString) {
414-
return boolType, info{}
415-
}
416-
417-
case "contains", "startsWith", "endsWith":
418-
if isString(l) && isString(r) {
419-
return boolType, info{}
420-
}
421-
if or(l, r, isString) {
422-
return boolType, info{}
423-
}
424-
425354
case "..":
426355
ret := reflect.SliceOf(integerType)
427356
if isInteger(l) && isInteger(r) {
@@ -448,7 +377,6 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
448377

449378
default:
450379
return v.error(node, "unknown operator (%v)", node.Operator)
451-
452380
}
453381

454382
return v.error(node, `invalid operation: %v (mismatched types %v and %v)`, node.Operator, l, r)
@@ -1207,3 +1135,95 @@ func (v *checker) PairNode(node *ast.PairNode) (reflect.Type, info) {
12071135
v.visit(node.Value)
12081136
return nilType, info{}
12091137
}
1138+
1139+
func (v *checker) CompareNode(node *ast.CompareNode) (reflect.Type, info) {
1140+
nodeLeft := node.Left
1141+
opIdx := 0
1142+
operatorOverride := false
1143+
for i, comparator := range node.Comparators {
1144+
op := node.Operators[opIdx]
1145+
if negate := op == "not"; negate {
1146+
opIdx++
1147+
op = node.Operators[opIdx]
1148+
}
1149+
if op == "&&" {
1150+
if !operatorOverride {
1151+
operatorOverride = true
1152+
}
1153+
} else if err := v.compareNode(op, nodeLeft, comparator, i); err != nil {
1154+
return v.error(comparator, err.Error())
1155+
}
1156+
opIdx++
1157+
nodeLeft = comparator
1158+
}
1159+
if operatorOverride {
1160+
return anyType, info{}
1161+
}
1162+
return boolType, info{}
1163+
}
1164+
1165+
func (v *checker) compareNode(op string, nodeLeft, nodeRight ast.Node, index int) error {
1166+
l, _ := v.visit(nodeLeft)
1167+
r, ri := v.visit(nodeRight)
1168+
l = deref.Type(l)
1169+
r = deref.Type(r)
1170+
switch op {
1171+
case "==", "!=":
1172+
if (isBool(r) && index > 0) || isComparable(l, r) {
1173+
return nil
1174+
}
1175+
case "<", ">", ">=", "<=":
1176+
if isNumber(l) && isNumber(r) ||
1177+
isString(l) && isString(r) ||
1178+
isTime(l) && isTime(r) ||
1179+
or(l, r, isNumber, isString, isTime) {
1180+
return nil
1181+
}
1182+
case "in":
1183+
if (isString(l) || isAny(l)) && isStruct(r) {
1184+
return nil
1185+
}
1186+
if isMap(r) {
1187+
if l == nil { // It is possible to compare with nil.
1188+
return nil
1189+
}
1190+
if !isAny(l) && !l.AssignableTo(r.Key()) {
1191+
return fmt.Errorf("cannot use %v as type %v in map key", l, r.Key())
1192+
}
1193+
return nil
1194+
}
1195+
if isArray(r) {
1196+
if l == nil { // It is possible to compare with nil.
1197+
return nil
1198+
}
1199+
if !isComparable(l, r.Elem()) {
1200+
return fmt.Errorf("cannot use %v as type %v in array", l, r.Elem())
1201+
}
1202+
if !isComparable(l, ri.elem) {
1203+
return fmt.Errorf("cannot use %v as type %v in array", l, ri.elem)
1204+
}
1205+
return nil
1206+
}
1207+
if (isAny(l) && anyOf(r, isString, isArray, isMap)) || isAny(r) {
1208+
return nil
1209+
}
1210+
1211+
case "matches":
1212+
if s, ok := nodeRight.(*ast.StringNode); ok {
1213+
if _, err := regexp.Compile(s.Value); err != nil {
1214+
return err
1215+
}
1216+
}
1217+
if (isString(l) && isString(r)) || or(l, r, isString) {
1218+
return nil
1219+
}
1220+
case "contains", "startsWith", "endsWith":
1221+
if isString(l) && isString(r) ||
1222+
or(l, r, isString) {
1223+
return nil
1224+
}
1225+
default:
1226+
return fmt.Errorf("unknown operator (%v)", op)
1227+
}
1228+
return fmt.Errorf(`invalid operation: %v (mismatched types %v and %v)`, op, l, r)
1229+
}

0 commit comments

Comments
 (0)