Skip to content

Commit 1845891

Browse files
committed
types: add iterators to each of the container types defined in the module.
Signed-off-by: Patrick Jakubowski <patrick.jakubowski@strongdm.com>
1 parent c1b29b6 commit 1845891

File tree

14 files changed

+397
-76
lines changed

14 files changed

+397
-76
lines changed

internal/eval/evalers.go

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -670,15 +670,12 @@ func (n *containsAllEval) Eval(env Env) (types.Value, error) {
670670
if err != nil {
671671
return zeroValue(), err
672672
}
673-
result := true
674-
rhs.Iterate(func(e types.Value) bool {
673+
for e := range rhs.All() {
675674
if !lhs.Contains(e) {
676-
result = false
677-
return false
675+
return types.Boolean(false), nil
678676
}
679-
return true
680-
})
681-
return types.Boolean(result), nil
677+
}
678+
return types.Boolean(true), nil
682679
}
683680

684681
// containsAnyEval
@@ -702,15 +699,12 @@ func (n *containsAnyEval) Eval(env Env) (types.Value, error) {
702699
if err != nil {
703700
return zeroValue(), err
704701
}
705-
result := false
706-
rhs.Iterate(func(e types.Value) bool {
702+
for e := range rhs.All() {
707703
if lhs.Contains(e) {
708-
result = true
709-
return false
704+
return types.Boolean(true), nil
710705
}
711-
return true
712-
})
713-
return types.Boolean(result), nil
706+
}
707+
return types.Boolean(false), nil
714708
}
715709

716710
// isEmptyEval
@@ -950,15 +944,14 @@ func entityInOne(env Env, entity types.EntityUID, parent types.EntityUID) bool {
950944
if fe.Parents.Contains(parent) {
951945
return true
952946
}
953-
fe.Parents.Iterate(func(k types.EntityUID) bool {
947+
for k := range fe.Parents.All() {
954948
p, ok := env.Entities.Get(k)
955949
if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) {
956-
return true
950+
continue
957951
}
958952
todo = append(todo, k)
959953
known.Add(k)
960-
return true
961-
})
954+
}
962955
}
963956
if len(todo) == 0 {
964957
return false
@@ -979,15 +972,14 @@ func entityInSet(env Env, entity types.EntityUID, parents mapset.Container[types
979972
if fe.Parents.Intersects(parents) {
980973
return true
981974
}
982-
fe.Parents.Iterate(func(k types.EntityUID) bool {
975+
for k := range fe.Parents.All() {
983976
p, ok := env.Entities.Get(k)
984977
if !ok || p.Parents.Len() == 0 || k == entity || known.Contains(k) {
985-
return true
978+
continue
986979
}
987980
todo = append(todo, k)
988981
known.Add(k)
989-
return true
990-
})
982+
}
991983
}
992984
if len(todo) == 0 {
993985
return false
@@ -1016,17 +1008,12 @@ func doInEval(env Env, lhs types.EntityUID, rhs types.Value) (types.Value, error
10161008
return types.Boolean(entityInOne(env, lhs, rhsv)), nil
10171009
case types.Set:
10181010
query := mapset.Make[types.EntityUID](rhsv.Len())
1019-
var err error
1020-
rhsv.Iterate(func(rhv types.Value) bool {
1021-
var e types.EntityUID
1022-
if e, err = ValueToEntity(rhv); err != nil {
1023-
return false
1011+
for rhv := range rhsv.All() {
1012+
e, err := ValueToEntity(rhv)
1013+
if err != nil {
1014+
return zeroValue(), err
10241015
}
10251016
query.Add(e)
1026-
return true
1027-
})
1028-
if err != nil {
1029-
return zeroValue(), err
10301017
}
10311018
return types.Boolean(entityInSet(env, lhs, query)), nil
10321019
}

internal/mapset/immutable.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mapset
22

33
import (
44
"encoding/json"
5+
"iter"
56
)
67

78
type ImmutableMapSet[T comparable] MapSet[T]
@@ -22,10 +23,17 @@ func (h ImmutableMapSet[T]) Intersects(o Container[T]) bool {
2223

2324
// Iterate the items in the set, calling callback for each item. If the callback returns false, iteration is halted.
2425
// Iteration order is undefined.
26+
//
27+
// Deprecated: Use All() instead.
2528
func (h ImmutableMapSet[T]) Iterate(callback func(item T) bool) {
2629
MapSet[T](h).Iterate(callback)
2730
}
2831

32+
// All returns an iterator over elements in the set. Iteration order is undefined.
33+
func (h ImmutableMapSet[T]) All() iter.Seq[T] {
34+
return MapSet[T](h).All()
35+
}
36+
2937
func (h ImmutableMapSet[T]) Slice() []T {
3038
return MapSet[T](h).Slice()
3139
}

internal/mapset/immutable_test.go

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,24 @@ func TestImmutableMapSet(t *testing.T) {
7373
t.Run("iterate", func(t *testing.T) {
7474
s1 := Immutable(1, 2, 3)
7575

76-
var items []int
76+
s2 := Make[int]()
7777
s1.Iterate(func(item int) bool {
78-
items = append(items, item)
78+
s2.Add(item)
7979
return true
8080
})
8181

82-
testutil.Equals(t, s1.Equal(Immutable(items...)), true)
82+
testutil.Equals(t, s1.Equal(s2), true)
8383
})
8484

8585
t.Run("iterate break early", func(t *testing.T) {
8686
s1 := Immutable(1, 2, 3)
8787

88-
i := 0
8988
var items []int
9089
s1.Iterate(func(item int) bool {
91-
if i == 2 {
90+
if len(items) == 2 {
9291
return false
9392
}
9493
items = append(items, item)
95-
i++
9694
return true
9795
})
9896

@@ -103,6 +101,35 @@ func TestImmutableMapSet(t *testing.T) {
103101
testutil.Equals(t, s1.Contains(items[1]), true)
104102
})
105103

104+
t.Run("all", func(t *testing.T) {
105+
s1 := Immutable(1, 2, 3)
106+
107+
s2 := Make[int]()
108+
for item := range s1.All() {
109+
s2.Add(item)
110+
}
111+
112+
testutil.Equals(t, s1.Equal(s2), true)
113+
})
114+
115+
t.Run("all break early", func(t *testing.T) {
116+
s1 := Immutable(1, 2, 3)
117+
118+
var items []int
119+
for item := range s1.All() {
120+
if len(items) == 2 {
121+
break
122+
}
123+
items = append(items, item)
124+
}
125+
126+
// Because iteration order is non-deterministic, all we can say is that the right number of items ended up in
127+
// the set and that the items were in the original set.
128+
testutil.Equals(t, len(items), 2)
129+
testutil.Equals(t, s1.Contains(items[0]), true)
130+
testutil.Equals(t, s1.Contains(items[1]), true)
131+
})
132+
106133
t.Run("intersection with overlap", func(t *testing.T) {
107134
s1 := Immutable(1, 2, 3)
108135
s2 := Immutable(2, 3, 4)

internal/mapset/mapset.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/json"
66
"fmt"
7+
"iter"
78
"slices"
89

910
"golang.org/x/exp/maps"
@@ -92,6 +93,8 @@ func (h MapSet[T]) Intersects(o Container[T]) bool {
9293

9394
// Iterate the items in the set, calling callback for each item. If the callback returns false, iteration is halted.
9495
// Iteration order is undefined.
96+
//
97+
// Deprecated: Use All() instead.
9598
func (h MapSet[T]) Iterate(callback func(item T) bool) {
9699
for item := range h.m {
97100
if !callback(item) {
@@ -100,6 +103,17 @@ func (h MapSet[T]) Iterate(callback func(item T) bool) {
100103
}
101104
}
102105

106+
// All returns an iterator over elements in the set. Iteration order is undefined.
107+
func (h MapSet[T]) All() iter.Seq[T] {
108+
return func(yield func(T) bool) {
109+
for item := range h.m {
110+
if !yield(item) {
111+
return
112+
}
113+
}
114+
}
115+
}
116+
103117
func (h MapSet[T]) Slice() []T {
104118
if h.m == nil {
105119
return nil

internal/mapset/mapset_test.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,12 @@ func TestMapSet(t *testing.T) {
109109
t.Run("iterate break early", func(t *testing.T) {
110110
s1 := FromItems(1, 2, 3)
111111

112-
i := 0
113112
var items []int
114113
s1.Iterate(func(item int) bool {
115-
if i == 2 {
114+
if len(items) == 2 {
116115
return false
117116
}
118117
items = append(items, item)
119-
i++
120118
return true
121119
})
122120

@@ -127,6 +125,35 @@ func TestMapSet(t *testing.T) {
127125
testutil.Equals(t, s1.Contains(items[1]), true)
128126
})
129127

128+
t.Run("all", func(t *testing.T) {
129+
s1 := FromItems(1, 2, 3)
130+
131+
s2 := Make[int]()
132+
for item := range s1.All() {
133+
s2.Add(item)
134+
}
135+
136+
testutil.Equals(t, s1.Equal(s2), true)
137+
})
138+
139+
t.Run("all break early", func(t *testing.T) {
140+
s1 := FromItems(1, 2, 3)
141+
142+
var items []int
143+
for item := range s1.All() {
144+
if len(items) == 2 {
145+
break
146+
}
147+
items = append(items, item)
148+
}
149+
150+
// Because iteration order is non-deterministic, all we can say is that the right number of items ended up in
151+
// the set and that the items were in the original set.
152+
testutil.Equals(t, len(items), 2)
153+
testutil.Equals(t, s1.Contains(items[0]), true)
154+
testutil.Equals(t, s1.Contains(items[1]), true)
155+
})
156+
130157
t.Run("intersection with overlap", func(t *testing.T) {
131158
s1 := FromItems(1, 2, 3)
132159
s2 := FromItems(2, 3, 4)

types.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,24 @@ type Wildcard = types.Wildcard
4343
type EntityGetter = types.EntityGetter
4444
type Value = types.Value
4545

46+
type Request = types.Request
47+
type Decision = types.Decision
48+
type Diagnostic = types.Diagnostic
49+
type DiagnosticReason = types.DiagnosticReason
50+
type DiagnosticError = types.DiagnosticError
51+
52+
const (
53+
Allow = types.Allow
54+
Deny = types.Deny
55+
)
56+
57+
type Effect = types.Effect
58+
59+
const (
60+
Permit = types.Permit
61+
Forbid = types.Forbid
62+
)
63+
4664
// ____ _ _
4765
// / ___|___ _ __ ___| |_ __ _ _ __ | |_ ___
4866
// | | / _ \| '_ \/ __| __/ _` | '_ \| __/ __|

types/entity.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@ type Entity struct {
1818
// SDK's behavior.
1919
func (e Entity) MarshalJSON() ([]byte, error) {
2020
parents := make([]ImplicitlyMarshaledEntityUID, 0, e.Parents.Len())
21-
e.Parents.Iterate(func(p EntityUID) bool {
21+
for p := range e.Parents.All() {
2222
parents = append(parents, ImplicitlyMarshaledEntityUID(p))
23-
return true
24-
})
23+
}
2524
slices.SortFunc(parents, func(a, b ImplicitlyMarshaledEntityUID) int {
2625
if cmp := strings.Compare(string(a.Type), string(b.Type)); cmp != 0 {
2726
return cmp

types/record.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/binary"
66
"encoding/json"
77
"hash/fnv"
8+
"iter"
89
"slices"
910
"strconv"
1011

@@ -53,6 +54,8 @@ func (r Record) Len() int {
5354
type RecordIterator func(String, Value) bool
5455

5556
// Iterate calls iter for each key/value pair in the record. Iteration order is non-deterministic.
57+
//
58+
// Deprecated: Use All(), Keys(), or Values() instead.
5659
func (r Record) Iterate(iter RecordIterator) {
5760
for k, v := range r.m {
5861
if !iter(k, v) {
@@ -61,6 +64,39 @@ func (r Record) Iterate(iter RecordIterator) {
6164
}
6265
}
6366

67+
// All returns an iterator over the keys and values in the Record. Iteration order is non-deterministic.
68+
func (r Record) All() iter.Seq2[String, Value] {
69+
return func(yield func(String, Value) bool) {
70+
for k, v := range r.m {
71+
if !yield(k, v) {
72+
break
73+
}
74+
}
75+
}
76+
}
77+
78+
// Keys returns an iterator over the keys in the Record. Iteration order is non-deterministic.
79+
func (r Record) Keys() iter.Seq[String] {
80+
return func(yield func(String) bool) {
81+
for k := range r.m {
82+
if !yield(k) {
83+
break
84+
}
85+
}
86+
}
87+
}
88+
89+
// Values returns an iterator over the keys in the Record. Iteration order is non-deterministic.
90+
func (r Record) Values() iter.Seq[Value] {
91+
return func(yield func(Value) bool) {
92+
for _, v := range r.m {
93+
if !yield(v) {
94+
break
95+
}
96+
}
97+
}
98+
}
99+
64100
// Get returns (v, true) where v is the Value associated with key s, if Record contains key s. Get returns (nil, false)
65101
// if Record does not contain key s.
66102
func (r Record) Get(s String) (Value, bool) {

types/record_internal_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"github.com/cedar-policy/cedar-go/internal/testutil"
77
)
88

9-
func TestRecord(t *testing.T) {
9+
func TestRecordInternal(t *testing.T) {
1010
t.Parallel()
1111

1212
t.Run("hash", func(t *testing.T) {

0 commit comments

Comments
 (0)