11package validate
22
33import (
4+ "cmp"
45 "fmt"
56 "slices"
67 "strings"
@@ -36,6 +37,28 @@ type entityAttrSource struct {
3637type typeEntity struct { lub entityLUB } // Entity with LUB of types
3738type typeExtension struct { name types.Ident } // Extension type (ipaddr, decimal, etc.)
3839
40+ var cedarTypeKindRanks = map [string ]int {
41+ fmt .Sprintf ("%T" , typeTrue {}): 0 ,
42+ fmt .Sprintf ("%T" , typeFalse {}): 1 ,
43+ fmt .Sprintf ("%T" , typeBool {}): 2 ,
44+ fmt .Sprintf ("%T" , typeNever {}): 3 ,
45+ fmt .Sprintf ("%T" , typeLong {}): 4 ,
46+ fmt .Sprintf ("%T" , typeString {}): 5 ,
47+ fmt .Sprintf ("%T" , typeSet {}): 6 ,
48+ fmt .Sprintf ("%T" , typeRecord {}): 7 ,
49+ fmt .Sprintf ("%T" , typeEntity {}): 8 ,
50+ fmt .Sprintf ("%T" , typeExtension {}): 9 ,
51+ }
52+
53+ var cedarPrimitiveTypeNames = map [string ]string {
54+ fmt .Sprintf ("%T" , typeNever {}): "__cedar::internal::Any" ,
55+ fmt .Sprintf ("%T" , typeTrue {}): "__cedar::internal::True" ,
56+ fmt .Sprintf ("%T" , typeFalse {}): "__cedar::internal::False" ,
57+ fmt .Sprintf ("%T" , typeBool {}): "Bool" ,
58+ fmt .Sprintf ("%T" , typeLong {}): "Long" ,
59+ fmt .Sprintf ("%T" , typeString {}): "String" ,
60+ }
61+
3962func (typeNever ) isCedarType () { _ = "hack for code coverage" }
4063func (typeTrue ) isCedarType () { _ = "hack for code coverage" }
4164func (typeFalse ) isCedarType () { _ = "hack for code coverage" }
@@ -51,7 +74,7 @@ func (typeExtension) isCedarType() { _ = "hack for code coverage" }
5174func typeIncompatErr (a , b cedarType ) * typeIncompatError {
5275 nameA := cedarTypeName (a )
5376 nameB := cedarTypeName (b )
54- if cedarTypeSortKey ( a ) > cedarTypeSortKey ( b ) {
77+ if compareCedarType ( a , b ) > 0 {
5578 nameA , nameB = nameB , nameA
5679 }
5780 return & typeIncompatError {msg : fmt .Sprintf ("the types %s and %s are not compatible" , nameA , nameB )}
@@ -64,14 +87,7 @@ func typeIncompatErrMulti(types []cedarType) *typeIncompatError {
6487 sorted := make ([]cedarType , len (types ))
6588 copy (sorted , types )
6689 slices .SortFunc (sorted , func (a , b cedarType ) int {
67- ka , kb := cedarTypeSortKey (a ), cedarTypeSortKey (b )
68- if ka < kb {
69- return - 1
70- }
71- if ka > kb {
72- return 1
73- }
74- return 0
90+ return compareCedarType (a , b )
7591 })
7692 names := make ([]string , len (sorted ))
7793 for i , t := range sorted {
@@ -99,74 +115,85 @@ func typeIncompatErrMulti(types []cedarType) *typeIncompatError {
99115 return & typeIncompatError {msg : sb .String ()}
100116}
101117
102- // cedarTypeSortKey returns a sort key for ordering types in error messages.
103- // Matches Rust's structural type ordering (True < False < Never < Long < String < Set < Record < Entity < Extension).
104- func cedarTypeSortKey (t cedarType ) string {
105- switch tv := t .(type ) {
106- case typeTrue :
107- return "0a"
108- case typeFalse :
109- return "0b"
110- case typeBool :
111- return "0c"
112- case typeNever :
113- return "1"
114- case typeLong :
115- return "2"
116- case typeString :
117- return "3"
118- case typeSet :
119- return "4:" + cedarTypeSortKey (tv .element )
120- case typeRecord :
121- // Sort by attribute keys/types (matches Rust BTreeMap ordering)
122- key := "5"
123- keys := make ([]string , 0 , len (tv .attrs ))
124- for k := range tv .attrs {
125- keys = append (keys , string (k ))
126- }
127- slices .Sort (keys )
128- for _ , k := range keys {
129- at := tv .attrs [types .String (k )]
130- key += ":" + k + ":" + cedarTypeSortKey (at .typ )
131- }
132- return key
133- case typeEntity :
134- return "6:" + cedarEntityTypeName (tv .lub )
135- case typeExtension :
118+ func compareCedarType (a , b cedarType ) int {
119+ ak , bk := cedarTypeKindRank (a ), cedarTypeKindRank (b )
120+ if ak != bk {
121+ return ak - bk
122+ }
123+
124+ if av , ok := a .(typeSet ); ok {
125+ return compareCedarType (av .element , b .(typeSet ).element )
126+ }
127+ if av , ok := a .(typeRecord ); ok {
128+ return compareRecordTypes (av , b .(typeRecord ))
129+ }
130+ if av , ok := a .(typeEntity ); ok {
131+ return compareEntityLUB (av .lub , b .(typeEntity ).lub )
132+ }
133+ return strings .Compare (cedarTypeName (a ), cedarTypeName (b ))
134+ }
135+
136+ func cedarTypeKindRank (t cedarType ) int {
137+ return cedarTypeKindRanks [fmt .Sprintf ("%T" , t )]
138+ }
139+
140+ func compareRecordTypes (a , b typeRecord ) int {
141+ ak := make ([]string , 0 , len (a .attrs ))
142+ for k := range a .attrs {
143+ ak = append (ak , string (k ))
136144 }
137- return "7:" + string (t .(typeExtension ).name )
145+ bk := make ([]string , 0 , len (b .attrs ))
146+ for k := range b .attrs {
147+ bk = append (bk , string (k ))
148+ }
149+ slices .Sort (ak )
150+ slices .Sort (bk )
151+
152+ n := len (ak )
153+ if len (bk ) < n {
154+ n = len (bk )
155+ }
156+ for i := 0 ; i < n ; i ++ {
157+ if c := strings .Compare (ak [i ], bk [i ]); c != 0 {
158+ return c
159+ }
160+ aat := a .attrs [types .String (ak [i ])]
161+ bat := b .attrs [types .String (bk [i ])]
162+ if c := compareCedarType (aat .typ , bat .typ ); c != 0 {
163+ return c
164+ }
165+ }
166+ return cmp .Compare (len (ak ), len (bk ))
167+ }
168+
169+ func compareEntityLUB (a , b entityLUB ) int {
170+ n := min (len (a .elements ), len (b .elements ))
171+ for i := 0 ; i < n ; i ++ {
172+ as , bs := string (a .elements [i ]), string (b .elements [i ])
173+ if c := strings .Compare (as , bs ); c != 0 {
174+ return c
175+ }
176+ }
177+ return cmp .Compare (len (a .elements ), len (b .elements ))
138178}
139179
140180// cedarTypeName returns the Rust Cedar display name for a type.
141181func cedarTypeName (t cedarType ) string {
142182 switch tv := t .(type ) {
143- case typeNever :
144- return "__cedar::internal::Any"
145- case typeTrue :
146- return "__cedar::internal::True"
147- case typeFalse :
148- return "__cedar::internal::False"
149- case typeBool :
150- return "Bool"
151- case typeLong :
152- return "Long"
153- case typeString :
154- return "String"
183+ case typeNever , typeTrue , typeFalse , typeBool , typeLong , typeString :
155184 case typeSet :
156185 return "Set<" + cedarTypeName (tv .element ) + ">"
157186 case typeRecord :
158187 return cedarRecordTypeName (tv )
159188 case typeEntity :
160189 return cedarEntityTypeName (tv .lub )
161190 case typeExtension :
191+ return string (tv .name )
162192 }
163- return string ( t .( typeExtension ). name )
193+ return cedarPrimitiveTypeNames [ fmt . Sprintf ( "%T" , t )]
164194}
165195
166196func cedarEntityTypeName (lub entityLUB ) string {
167- if len (lub .elements ) == 0 {
168- return "__cedar::internal::AnyEntity"
169- }
170197 if len (lub .elements ) == 1 {
171198 return string (lub .elements [0 ])
172199 }
@@ -192,6 +219,9 @@ func cedarRecordTypeName(r typeRecord) string {
192219 for _ , k := range keys {
193220 at := r .attrs [types .String (k )]
194221 sb .WriteString (k )
222+ if ! at .required {
223+ sb .WriteRune ('?' )
224+ }
195225 sb .WriteString (": " )
196226 sb .WriteString (cedarTypeName (at .typ ))
197227 sb .WriteRune (',' )
@@ -254,14 +284,13 @@ func (v *Validator) isSubtype(a, b cedarType) bool {
254284
255285// leastUpperBound computes the LUB of two types.
256286func (v * Validator ) leastUpperBound (a , b cedarType ) (cedarType , error ) {
257- if _ , ok := a .(typeNever ); ok {
258- return b , nil
259- }
260287 if _ , ok := b .(typeNever ); ok {
261288 return a , nil
262289 }
263290
264291 switch av := a .(type ) {
292+ case typeNever :
293+ return b , nil
265294 case typeTrue :
266295 switch b .(type ) {
267296 case typeTrue :
@@ -279,10 +308,8 @@ func (v *Validator) leastUpperBound(a, b cedarType) (cedarType, error) {
279308 case typeNever , typeLong , typeString , typeSet , typeRecord , typeEntity , typeExtension :
280309 }
281310 case typeBool :
282- switch b .(type ) {
283- case typeTrue , typeFalse , typeBool :
311+ if isBoolType (b ) {
284312 return typeBool {}, nil
285- case typeNever , typeLong , typeString , typeSet , typeRecord , typeEntity , typeExtension :
286313 }
287314 case typeLong :
288315 if _ , ok := b .(typeLong ); ok {
@@ -312,8 +339,6 @@ func (v *Validator) leastUpperBound(a, b cedarType) (cedarType, error) {
312339 if bv , ok := b .(typeExtension ); ok && av .name == bv .name {
313340 return av , nil
314341 }
315- case typeNever :
316- // Already handled above; unreachable.
317342 }
318343
319344 return nil , fmt .Errorf ("incompatible types for least upper bound" )
0 commit comments