Skip to content

Commit c3c8479

Browse files
committed
Add DOT graph export for EntityMap
Signed-off-by: Pierre-Henri Symoneaux <pierre-henri.symoneaux@ovhcloud.com>
1 parent ab3e2e2 commit c3c8479

File tree

2 files changed

+246
-0
lines changed

2 files changed

+246
-0
lines changed

x/exp/dot/dot.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package dot
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"iter"
7+
"strconv"
8+
9+
"github.com/cedar-policy/cedar-go/types"
10+
)
11+
12+
// Write takes an entity iterator and writes a DOT graph representing entities relationship.
13+
//
14+
// This function only returns an error on a failing write to w, so it is infallible if the Writer implementation cannot fail.
15+
func Write(w io.Writer, entities iter.Seq[types.Entity]) error {
16+
// write prelude
17+
if _, err := fmt.Fprintln(w, "strict digraph {\n\tordering=\"out\"\n\tnode[shape=box]"); err != nil {
18+
return err
19+
}
20+
21+
// write clusters (subgraphs)
22+
entitiesByType := getEntitiesByEntityType(entities)
23+
24+
for et, entities := range entitiesByType {
25+
if _, err := fmt.Fprintf(w, "\tsubgraph \"cluster_%s\" {\n\t\tlabel=%s\n", et, toDotID(string(et))); err != nil {
26+
return err
27+
}
28+
for _, entity := range entities {
29+
if _, err := fmt.Fprintf(w, "\t\t%s [label=%s]\n", toDotID(entity.UID.String()), toDotID(entity.UID.ID.String())); err != nil {
30+
return err
31+
}
32+
}
33+
if _, err := fmt.Fprintln(w, "\t}"); err != nil {
34+
return err
35+
}
36+
}
37+
38+
// adding edges
39+
for entity := range entities {
40+
for ancestor := range entity.Parents.All() {
41+
if _, err := fmt.Fprintf(w, "\t%s -> %s\n", toDotID(entity.UID.String()), toDotID(ancestor.String())); err != nil {
42+
return err
43+
}
44+
}
45+
}
46+
if _, err := fmt.Fprintln(w, "}"); err != nil {
47+
return err
48+
}
49+
return nil
50+
}
51+
52+
func toDotID(v string) string {
53+
// From DOT language reference:
54+
// An ID is one of the following:
55+
// Any string of alphabetic ([a-zA-Z\200-\377]) characters, underscores ('_') or digits([0-9]), not beginning with a digit;
56+
// a numeral [-]?(.[0-9]⁺ | [0-9]⁺(.[0-9]*)? );
57+
// any double-quoted string ("...") possibly containing escaped quotes (\");
58+
// an HTML string (<...>).
59+
// The best option to convert a `Name` or an `EntityUid` is to use double-quoted string.
60+
// The `strconv.Quote` function should be sufficient for our purpose.
61+
return strconv.Quote(v)
62+
}
63+
64+
func getEntitiesByEntityType(entities iter.Seq[types.Entity]) map[types.EntityType][]types.Entity {
65+
entitiesByType := map[types.EntityType][]types.Entity{}
66+
for entity := range entities {
67+
euid := entity.UID
68+
entityType := euid.Type
69+
if entities, ok := entitiesByType[entityType]; ok {
70+
entitiesByType[entityType] = append(entities, entity)
71+
} else {
72+
entitiesByType[entityType] = []types.Entity{entity}
73+
}
74+
}
75+
return entitiesByType
76+
}

x/exp/dot/dot_test.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package dot
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"maps"
7+
"strconv"
8+
"strings"
9+
"testing"
10+
11+
"github.com/cedar-policy/cedar-go/types"
12+
)
13+
14+
func TestWrite(t *testing.T) {
15+
t.Run("WritesNodesAndEdges", func(t *testing.T) {
16+
var buf bytes.Buffer
17+
18+
// Build a small graph:
19+
// Group::admins (no parents)
20+
// User::alice (parent = Group::admins)
21+
// User::bob (no parents)
22+
groupUID := types.NewEntityUID("Group", types.String("admins"))
23+
aliceUID := types.NewEntityUID("User", types.String("alice"))
24+
bobUID := types.NewEntityUID("User", types.String("bob"))
25+
26+
entities := types.EntityMap{}
27+
entities[groupUID] = types.Entity{
28+
UID: groupUID,
29+
Parents: types.NewEntityUIDSet(), // no parents
30+
}
31+
entities[aliceUID] = types.Entity{
32+
UID: aliceUID,
33+
Parents: types.NewEntityUIDSet(groupUID), // parent is group
34+
}
35+
entities[bobUID] = types.Entity{
36+
UID: bobUID,
37+
Parents: types.NewEntityUIDSet(), // no parents
38+
}
39+
40+
if err := Write(&buf, maps.Values(entities)); err != nil {
41+
t.Fatalf("ToDotStr returned error: %v", err)
42+
}
43+
out := buf.String()
44+
45+
// Basic prelude should be present
46+
if !strings.Contains(out, "strict digraph") {
47+
t.Fatalf("output missing digraph prelude: %q", out)
48+
}
49+
50+
// Each entity should be present as a node with a quoted label matching the ID
51+
expectedGroupNode := fmt.Sprintf("\t\t%q [label=%q]\n", groupUID, groupUID.ID)
52+
if !strings.Contains(out, expectedGroupNode) {
53+
t.Errorf("expected group node line %q not found in output:\n%s", expectedGroupNode, out)
54+
}
55+
56+
expectedAliceNode := fmt.Sprintf("\t\t%q [label=%q]\n", aliceUID, aliceUID.ID)
57+
if !strings.Contains(out, expectedAliceNode) {
58+
t.Errorf("expected alice node line %q not found in output:\n%s", expectedAliceNode, out)
59+
}
60+
61+
expectedBobNode := fmt.Sprintf("\t\t%q [label=%q]\n", bobUID, bobUID.ID)
62+
if !strings.Contains(out, expectedBobNode) {
63+
t.Errorf("expected bob node line %q not found in output:\n%s", expectedBobNode, out)
64+
}
65+
66+
// Edge from alice to group should be present
67+
expectedEdge := fmt.Sprintf("\t%q -> %q\n", aliceUID, groupUID)
68+
if !strings.Contains(out, expectedEdge) {
69+
t.Errorf("expected edge %q not found in output:\n%s", expectedEdge, out)
70+
}
71+
})
72+
73+
t.Run("NoEdgesWhenNoParents", func(t *testing.T) {
74+
var buf bytes.Buffer
75+
76+
// Two entities of different types with no parents; output must contain nodes but no edges
77+
uidA := types.NewEntityUID("TypeA", types.String("a1"))
78+
uidB := types.NewEntityUID("TypeB", types.String("b1"))
79+
80+
entities := types.EntityMap{
81+
uidA: {UID: uidA, Parents: types.NewEntityUIDSet()},
82+
uidB: {UID: uidB, Parents: types.NewEntityUIDSet()},
83+
}
84+
85+
if err := Write(&buf, maps.Values(entities)); err != nil {
86+
t.Fatalf("ToDotStr returned error: %v", err)
87+
}
88+
out := buf.String()
89+
90+
// Ensure nodes exist
91+
if !strings.Contains(out, strconv.Quote(uidA.String())) {
92+
t.Errorf("expected node for uidA %q not present", uidA.String())
93+
}
94+
if !strings.Contains(out, strconv.Quote(uidB.String())) {
95+
t.Errorf("expected node for uidB %q not present", uidB.String())
96+
}
97+
98+
// Ensure there are no edges in the graph
99+
if strings.Contains(out, "->") {
100+
t.Errorf("did not expect any edges, but found some in output:\n%s", out)
101+
}
102+
})
103+
104+
t.Run("WriterFailure", func(t *testing.T) {
105+
// Build entities with multiple types and parents to trigger all write paths:
106+
// - prelude write
107+
// - subgraph header write (first type)
108+
// - node write (first type)
109+
// - subgraph close write (first type)
110+
// - subgraph header write (second type)
111+
// - node write (second type)
112+
// - subgraph close write (second type)
113+
// - edge write
114+
// - final close write
115+
groupUID := types.NewEntityUID("Group", types.String("admins"))
116+
aliceUID := types.NewEntityUID("User", types.String("alice"))
117+
bobUID := types.NewEntityUID("User", types.String("bob"))
118+
119+
entities := types.EntityMap{
120+
groupUID: {UID: groupUID, Parents: types.NewEntityUIDSet()},
121+
aliceUID: {UID: aliceUID, Parents: types.NewEntityUIDSet(groupUID)},
122+
bobUID: {UID: bobUID, Parents: types.NewEntityUIDSet()},
123+
}
124+
125+
// Test each failure point by allowing N successful writes before failing
126+
testCases := []struct {
127+
name string
128+
allowedWrites int
129+
expectedErrorMsg string
130+
}{
131+
{"FailOnPrelude", 0, "write failed"},
132+
{"FailOnFirstSubgraphHeader", 1, "write failed"},
133+
{"FailOnFirstNodeWrite", 2, "write failed"},
134+
{"FailOnSecondNodeWrite", 3, "write failed"},
135+
{"FailOnFirstSubgraphClose", 4, "write failed"},
136+
{"FailOnSecondSubgraphHeader", 5, "write failed"},
137+
{"FailOnSecondTypeNodeWrite", 6, "write failed"},
138+
{"FailOnSecondSubgraphClose", 7, "write failed"},
139+
{"FailOnEdgeWrite", 8, "write failed"},
140+
{"FailOnFinalClose", 9, "write failed"},
141+
}
142+
143+
for _, tc := range testCases {
144+
t.Run(tc.name, func(t *testing.T) {
145+
failingWriter := &failAfterNWriter{allowedWrites: tc.allowedWrites}
146+
err := Write(failingWriter, maps.Values(entities))
147+
if err == nil {
148+
t.Fatal("expected Write to return error when writer fails, got nil")
149+
}
150+
if !strings.Contains(err.Error(), tc.expectedErrorMsg) {
151+
t.Errorf("expected error message to contain %q, got: %v", tc.expectedErrorMsg, err)
152+
}
153+
})
154+
}
155+
})
156+
}
157+
158+
// failAfterNWriter is a writer that fails after N successful writes
159+
type failAfterNWriter struct {
160+
allowedWrites int
161+
writeCount int
162+
}
163+
164+
func (f *failAfterNWriter) Write(p []byte) (n int, err error) {
165+
if f.writeCount >= f.allowedWrites {
166+
return 0, fmt.Errorf("write failed")
167+
}
168+
f.writeCount++
169+
return len(p), nil
170+
}

0 commit comments

Comments
 (0)