11package validate
22
33import (
4+ "cmp"
45 "fmt"
56 "slices"
67 "strings"
@@ -51,7 +52,7 @@ func (typeExtension) isCedarType() { _ = "hack for code coverage" }
5152func typeIncompatErr (a , b cedarType ) * typeIncompatError {
5253 nameA := cedarTypeName (a )
5354 nameB := cedarTypeName (b )
54- if cedarTypeSortKey ( a ) > cedarTypeSortKey ( b ) {
55+ if compareCedarType ( a , b ) > 0 {
5556 nameA , nameB = nameB , nameA
5657 }
5758 return & typeIncompatError {msg : fmt .Sprintf ("the types %s and %s are not compatible" , nameA , nameB )}
@@ -64,14 +65,7 @@ func typeIncompatErrMulti(types []cedarType) *typeIncompatError {
6465 sorted := make ([]cedarType , len (types ))
6566 copy (sorted , types )
6667 slices .SortFunc (sorted , func (a , b cedarType ) int {
67- ka , kb := cedarTypeSortKey (a ), cedarTypeSortKey (b )
68- if ka < kb {
69- return - 1
70- }
71- if ka > kb {
72- return 1
73- }
74- return 0
68+ return compareCedarType (a , b )
7569 })
7670 names := make ([]string , len (sorted ))
7771 for i , t := range sorted {
@@ -99,57 +93,104 @@ func typeIncompatErrMulti(types []cedarType) *typeIncompatError {
9993 return & typeIncompatError {msg : sb .String ()}
10094}
10195
102- // cedarTypeSortKey returns a sort key for ordering types in error messages.
103- // Matches Rust's structural type ordering (True < False < Never < Long < String < Set < Record < Entity < Extension).
104- func cedarTypeSortKey (t cedarType ) string {
105- switch tv := t .(type ) {
96+ func compareCedarType (a , b cedarType ) int {
97+ ak , bk := cedarTypeKindRank (a ), cedarTypeKindRank (b )
98+ if ak != bk {
99+ return ak - bk
100+ }
101+
102+ if av , ok := a .(typeSet ); ok {
103+ return compareCedarType (av .element , b .(typeSet ).element )
104+ }
105+ if av , ok := a .(typeRecord ); ok {
106+ return compareRecordTypes (av , b .(typeRecord ))
107+ }
108+ if av , ok := a .(typeEntity ); ok {
109+ return compareEntityLUB (av .lub , b .(typeEntity ).lub )
110+ }
111+ return strings .Compare (cedarTypeName (a ), cedarTypeName (b ))
112+ }
113+
114+ func cedarTypeKindRank (t cedarType ) int {
115+ switch t .(type ) {
106116 case typeTrue :
107- return "0a"
117+ return 0
108118 case typeFalse :
109- return "0b"
119+ return 1
110120 case typeBool :
111- return "0c"
112- case typeNever :
113- return "1"
121+ return 2
114122 case typeLong :
115- return "2"
123+ return 4
124+ case typeNever :
125+ return 3
116126 case typeString :
117- return "3"
127+ return 5
118128 case typeSet :
119- return "4:" + cedarTypeSortKey ( tv . element )
129+ return 6
120130 case typeRecord :
121- // Sort by attribute keys/types (matches Rust BTreeMap ordering)
122- key := "5"
123- keys := make ([]string , 0 , len (tv .attrs ))
124- for k := range tv .attrs {
125- keys = append (keys , string (k ))
126- }
127- slices .Sort (keys )
128- for _ , k := range keys {
129- at := tv .attrs [types .String (k )]
130- key += ":" + k + ":" + cedarTypeSortKey (at .typ )
131- }
132- return key
131+ return 7
133132 case typeEntity :
134- return "6:" + cedarEntityTypeName ( tv . lub )
133+ return 8
135134 case typeExtension :
135+ return 9
136+ default :
137+ return - 1
136138 }
137- return "7:" + string (t .(typeExtension ).name )
139+ }
140+
141+ func compareRecordTypes (a , b typeRecord ) int {
142+ ak := make ([]string , 0 , len (a .attrs ))
143+ for k := range a .attrs {
144+ ak = append (ak , string (k ))
145+ }
146+ bk := make ([]string , 0 , len (b .attrs ))
147+ for k := range b .attrs {
148+ bk = append (bk , string (k ))
149+ }
150+ slices .Sort (ak )
151+ slices .Sort (bk )
152+
153+ n := len (ak )
154+ if len (bk ) < n {
155+ n = len (bk )
156+ }
157+ for i := 0 ; i < n ; i ++ {
158+ if c := strings .Compare (ak [i ], bk [i ]); c != 0 {
159+ return c
160+ }
161+ aat := a .attrs [types .String (ak [i ])]
162+ bat := b .attrs [types .String (bk [i ])]
163+ if c := compareCedarType (aat .typ , bat .typ ); c != 0 {
164+ return c
165+ }
166+ }
167+ return cmp .Compare (len (ak ), len (bk ))
168+ }
169+
170+ func compareEntityLUB (a , b entityLUB ) int {
171+ n := min (len (a .elements ), len (b .elements ))
172+ for i := 0 ; i < n ; i ++ {
173+ as , bs := string (a .elements [i ]), string (b .elements [i ])
174+ if c := strings .Compare (as , bs ); c != 0 {
175+ return c
176+ }
177+ }
178+ return cmp .Compare (len (a .elements ), len (b .elements ))
138179}
139180
140181// cedarTypeName returns the Rust Cedar display name for a type.
141182func cedarTypeName (t cedarType ) string {
142183 switch tv := t .(type ) {
143184 case typeNever :
144- return "__cedar::internal::Any"
185+ return "__cedar::internal::Never"
186+ case typeLong :
187+ return "Long"
145188 case typeTrue :
146189 return "__cedar::internal::True"
147190 case typeFalse :
148191 return "__cedar::internal::False"
149192 case typeBool :
150193 return "Bool"
151- case typeLong :
152- return "Long"
153194 case typeString :
154195 return "String"
155196 case typeSet :
@@ -159,14 +200,13 @@ func cedarTypeName(t cedarType) string {
159200 case typeEntity :
160201 return cedarEntityTypeName (tv .lub )
161202 case typeExtension :
203+ return string (tv .name )
204+ default :
205+ return "?"
162206 }
163- return string (t .(typeExtension ).name )
164207}
165208
166209func cedarEntityTypeName (lub entityLUB ) string {
167- if len (lub .elements ) == 0 {
168- return "__cedar::internal::AnyEntity"
169- }
170210 if len (lub .elements ) == 1 {
171211 return string (lub .elements [0 ])
172212 }
@@ -192,6 +232,9 @@ func cedarRecordTypeName(r typeRecord) string {
192232 for _ , k := range keys {
193233 at := r .attrs [types .String (k )]
194234 sb .WriteString (k )
235+ if ! at .required {
236+ sb .WriteRune ('?' )
237+ }
195238 sb .WriteString (": " )
196239 sb .WriteString (cedarTypeName (at .typ ))
197240 sb .WriteRune (',' )
@@ -254,14 +297,13 @@ func (v *Validator) isSubtype(a, b cedarType) bool {
254297
255298// leastUpperBound computes the LUB of two types.
256299func (v * Validator ) leastUpperBound (a , b cedarType ) (cedarType , error ) {
257- if _ , ok := a .(typeNever ); ok {
258- return b , nil
259- }
260300 if _ , ok := b .(typeNever ); ok {
261301 return a , nil
262302 }
263303
264304 switch av := a .(type ) {
305+ case typeNever :
306+ return b , nil
265307 case typeTrue :
266308 switch b .(type ) {
267309 case typeTrue :
@@ -279,10 +321,8 @@ func (v *Validator) leastUpperBound(a, b cedarType) (cedarType, error) {
279321 case typeNever , typeLong , typeString , typeSet , typeRecord , typeEntity , typeExtension :
280322 }
281323 case typeBool :
282- switch b .(type ) {
283- case typeTrue , typeFalse , typeBool :
324+ if isBoolType (b ) {
284325 return typeBool {}, nil
285- case typeNever , typeLong , typeString , typeSet , typeRecord , typeEntity , typeExtension :
286326 }
287327 case typeLong :
288328 if _ , ok := b .(typeLong ); ok {
@@ -312,8 +352,6 @@ func (v *Validator) leastUpperBound(a, b cedarType) (cedarType, error) {
312352 if bv , ok := b .(typeExtension ); ok && av .name == bv .name {
313353 return av , nil
314354 }
315- case typeNever :
316- // Already handled above; unreachable.
317355 }
318356
319357 return nil , fmt .Errorf ("incompatible types for least upper bound" )
0 commit comments