Skip to content

Commit 9118a4b

Browse files
authored
Merge pull request #92 from strongdm/codex/add-ast-walk-feature-for-isnode-interface
Add AST Inspect feature for IsNode objects
2 parents 159f227 + 910f147 commit 9118a4b

File tree

3 files changed

+159
-10
lines changed

3 files changed

+159
-10
lines changed

x/exp/ast/inspect.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package ast
2+
3+
// Inspect traverses an AST in depth-first order. For each node, the
4+
// provided function fn is called. If fn returns true, Inspect will
5+
// recursively inspect the node's children. Returning false skips the
6+
// children of that node.
7+
func Inspect(n Node, fn func(IsNode) bool) {
8+
inspectNode(n.v, fn)
9+
}
10+
11+
func inspectNode(n IsNode, fn func(IsNode) bool) {
12+
if n == nil {
13+
return
14+
}
15+
if !fn(n) {
16+
return
17+
}
18+
n.inspect(fn)
19+
}

x/exp/ast/inspect_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package ast
2+
3+
import (
4+
"testing"
5+
6+
"github.com/cedar-policy/cedar-go/internal/testutil"
7+
"github.com/cedar-policy/cedar-go/types"
8+
)
9+
10+
func TestInspectCounts(t *testing.T) {
11+
t.Parallel()
12+
leaf1 := NodeValue{Value: types.Long(1)}
13+
leaf2 := NodeValue{Value: types.Long(2)}
14+
cases := []struct {
15+
name string
16+
node Node
17+
want int
18+
}{
19+
{"IfThenElse", NewNode(NodeTypeIfThenElse{If: leaf1, Then: leaf1, Else: leaf1}), 4},
20+
{"Or", NewNode(NodeTypeOr{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
21+
{"And", NewNode(NodeTypeAnd{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
22+
{"LessThan", NewNode(NodeTypeLessThan{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
23+
{"LessThanOrEqual", NewNode(NodeTypeLessThanOrEqual{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
24+
{"GreaterThan", NewNode(NodeTypeGreaterThan{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
25+
{"GreaterThanOrEqual", NewNode(NodeTypeGreaterThanOrEqual{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
26+
{"NotEquals", NewNode(NodeTypeNotEquals{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
27+
{"Equals", NewNode(NodeTypeEquals{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
28+
{"In", NewNode(NodeTypeIn{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
29+
{"HasTag", NewNode(NodeTypeHasTag{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
30+
{"GetTag", NewNode(NodeTypeGetTag{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
31+
{"Contains", NewNode(NodeTypeContains{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
32+
{"ContainsAll", NewNode(NodeTypeContainsAll{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
33+
{"ContainsAny", NewNode(NodeTypeContainsAny{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
34+
{"Add", NewNode(NodeTypeAdd{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
35+
{"Sub", NewNode(NodeTypeSub{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
36+
{"Mult", NewNode(NodeTypeMult{BinaryNode: BinaryNode{Left: leaf1, Right: leaf2}}), 3},
37+
{"Has", NewNode(NodeTypeHas{StrOpNode: StrOpNode{Arg: leaf1, Value: "a"}}), 2},
38+
{"Access", NewNode(NodeTypeAccess{StrOpNode: StrOpNode{Arg: leaf1, Value: "a"}}), 2},
39+
{"Like", NewNode(NodeTypeLike{Arg: leaf1, Value: types.NewPattern(types.Wildcard{})}), 2},
40+
{"Is", NewNode(NodeTypeIs{Left: leaf1, EntityType: "T"}), 2},
41+
{"IsIn", NewNode(NodeTypeIsIn{NodeTypeIs: NodeTypeIs{Left: leaf1, EntityType: "T"}, Entity: leaf2}), 3},
42+
{"Negate", NewNode(NodeTypeNegate{UnaryNode: UnaryNode{Arg: leaf1}}), 2},
43+
{"Not", NewNode(NodeTypeNot{UnaryNode: UnaryNode{Arg: leaf1}}), 2},
44+
{"IsEmpty", NewNode(NodeTypeIsEmpty{UnaryNode: UnaryNode{Arg: leaf1}}), 2},
45+
{"ExtensionCall", NewNode(NodeTypeExtensionCall{Name: "f", Args: []IsNode{leaf1, leaf2}}), 3},
46+
{"Record", NewNode(NodeTypeRecord{Elements: []RecordElementNode{{Key: "k", Value: leaf1}}}), 2},
47+
{"Set", NewNode(NodeTypeSet{Elements: []IsNode{leaf1, leaf2}}), 3},
48+
{"Variable", NewNode(NodeTypeVariable{Name: "v"}), 1},
49+
{"StrOpNode", NewNode(StrOpNode{Arg: leaf1, Value: "a"}), 2},
50+
{"BinaryNode", NewNode(BinaryNode{Left: leaf1, Right: leaf2}), 3},
51+
{"UnaryNode", NewNode(UnaryNode{Arg: leaf1}), 2},
52+
{"Value", NewNode(leaf1), 1},
53+
}
54+
55+
for _, tt := range cases {
56+
tt := tt
57+
t.Run(tt.name, func(t *testing.T) {
58+
t.Parallel()
59+
count := 0
60+
Inspect(tt.node, func(IsNode) bool { count++; return true })
61+
testutil.Equals(t, count, tt.want)
62+
})
63+
}
64+
}
65+
66+
func TestInspectSkipChildren(t *testing.T) {
67+
t.Parallel()
68+
leaf := NewNode(NodeValue{Value: types.Long(1)})
69+
root := NewNode(NodeTypeAnd{BinaryNode: BinaryNode{Left: leaf.v, Right: leaf.v}})
70+
var count int
71+
Inspect(root, func(n IsNode) bool {
72+
count++
73+
if _, ok := n.(NodeTypeAnd); ok {
74+
return false
75+
}
76+
return true
77+
})
78+
testutil.Equals(t, count, 1)
79+
}
80+
81+
func TestInspectNil(t *testing.T) {
82+
t.Parallel()
83+
var c int
84+
Inspect(Node{}, func(IsNode) bool { c++; return true })
85+
testutil.Equals(t, c, 0)
86+
}

x/exp/ast/node.go

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,31 @@ type StrOpNode struct {
2121
Value types.String
2222
}
2323

24-
func (n StrOpNode) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
24+
func (n StrOpNode) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
25+
func (n StrOpNode) inspect(fn func(IsNode) bool) {
26+
inspectNode(n.Arg, fn)
27+
}
2528

2629
type BinaryNode struct {
2730
Left, Right IsNode
2831
}
2932

30-
func (n BinaryNode) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
33+
func (n BinaryNode) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
34+
func (n BinaryNode) inspect(fn func(IsNode) bool) {
35+
inspectNode(n.Left, fn)
36+
inspectNode(n.Right, fn)
37+
}
3138

3239
type NodeTypeIfThenElse struct {
3340
If, Then, Else IsNode
3441
}
3542

3643
func (n NodeTypeIfThenElse) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
44+
func (n NodeTypeIfThenElse) inspect(fn func(IsNode) bool) {
45+
inspectNode(n.If, fn)
46+
inspectNode(n.Then, fn)
47+
inspectNode(n.Else, fn)
48+
}
3749

3850
type NodeTypeOr struct{ BinaryNode }
3951

@@ -74,20 +86,31 @@ type NodeTypeLike struct {
7486
Value types.Pattern
7587
}
7688

77-
func (n NodeTypeLike) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
89+
func (n NodeTypeLike) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
90+
func (n NodeTypeLike) inspect(fn func(IsNode) bool) {
91+
inspectNode(n.Arg, fn)
92+
}
7893

7994
type NodeTypeIs struct {
8095
Left IsNode
8196
EntityType types.EntityType
8297
}
8398

84-
func (n NodeTypeIs) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
99+
func (n NodeTypeIs) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
100+
func (n NodeTypeIs) inspect(fn func(IsNode) bool) {
101+
inspectNode(n.Left, fn)
102+
}
85103

86104
type NodeTypeIsIn struct {
87105
NodeTypeIs
88106
Entity IsNode
89107
}
90108

109+
func (n NodeTypeIsIn) inspect(fn func(IsNode) bool) {
110+
n.NodeTypeIs.inspect(fn)
111+
inspectNode(n.Entity, fn)
112+
}
113+
91114
type AddNode struct{}
92115

93116
type NodeTypeSub struct {
@@ -106,7 +129,10 @@ type UnaryNode struct {
106129
Arg IsNode
107130
}
108131

109-
func (n UnaryNode) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
132+
func (n UnaryNode) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
133+
func (n UnaryNode) inspect(fn func(IsNode) bool) {
134+
inspectNode(n.Arg, fn)
135+
}
110136

111137
type NodeTypeNegate struct{ UnaryNode }
112138
type NodeTypeNot struct{ UnaryNode }
@@ -120,7 +146,12 @@ type NodeTypeExtensionCall struct {
120146
Args []IsNode
121147
}
122148

123-
func (n NodeTypeExtensionCall) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
149+
func (n NodeTypeExtensionCall) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
150+
func (n NodeTypeExtensionCall) inspect(fn func(IsNode) bool) {
151+
for _, a := range n.Args {
152+
inspectNode(a, fn)
153+
}
154+
}
124155

125156
func stripNodes(args []Node) []IsNode {
126157
if args == nil {
@@ -170,7 +201,8 @@ type NodeValue struct {
170201
Value types.Value
171202
}
172203

173-
func (n NodeValue) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
204+
func (n NodeValue) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
205+
func (n NodeValue) inspect(func(IsNode) bool) { _ = 0 } // No-op statements injected for code coverage instrumentation
174206

175207
type RecordElementNode struct {
176208
Key types.String
@@ -181,20 +213,32 @@ type NodeTypeRecord struct {
181213
Elements []RecordElementNode
182214
}
183215

184-
func (n NodeTypeRecord) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
216+
func (n NodeTypeRecord) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
217+
func (n NodeTypeRecord) inspect(fn func(IsNode) bool) {
218+
for _, e := range n.Elements {
219+
inspectNode(e.Value, fn)
220+
}
221+
}
185222

186223
type NodeTypeSet struct {
187224
Elements []IsNode
188225
}
189226

190-
func (n NodeTypeSet) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
227+
func (n NodeTypeSet) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
228+
func (n NodeTypeSet) inspect(fn func(IsNode) bool) {
229+
for _, e := range n.Elements {
230+
inspectNode(e, fn)
231+
}
232+
}
191233

192234
type NodeTypeVariable struct {
193235
Name types.String
194236
}
195237

196-
func (n NodeTypeVariable) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation{}
238+
func (n NodeTypeVariable) isNode() { _ = 0 } // No-op statement injected for code coverage instrumentation
239+
func (n NodeTypeVariable) inspect(func(IsNode) bool) { _ = 0 } // No-op statements injected for code coverage instrumentation
197240

198241
type IsNode interface {
199242
isNode()
243+
inspect(func(IsNode) bool)
200244
}

0 commit comments

Comments
 (0)