@@ -16,8 +16,10 @@ package replay
1616
1717import (
1818 "encoding/json"
19+ "errors"
1920 "fmt"
2021 "sort"
22+ "strings"
2123 "testing"
2224
2325 "github.com/stretchr/testify/assert"
@@ -37,6 +39,23 @@ type testingT interface {
3739//
3840// {"\\": x} matches only JSON documents strictly equal to x. This pattern essentially escapes the sub-tree, for example
3941// use {"\\": "*"} to match only the literal string "*".
42+ //
43+ // An object pattern {"key1": "pattern1", "key2": "pattern2"} matches objects in a natural manner. By default it will
44+ // only match objects with the exact set of keys specified. To tolerate extraneous keys, a catch-all pattern can be
45+ // specified as follows, to match against all unspecified keys:
46+ //
47+ // {"key1": "pattern1", "key2": "pattern2", "*": "catch-all-pattern"}
48+ //
49+ // In particular this can be used to ignore all extraneous keys:
50+ //
51+ // {"key1": "pattern1", "key2": "pattern2", "*": "*"}
52+ //
53+ // It is possible to escape keys in an object pattern by prefixing them with "\\", for example this pattern:
54+ //
55+ // {"\\*": "foo"}
56+ //
57+ // This pattern will only match the object {"*": "foo"}, that is the wildcard is interpreted literally and not as the
58+ // catch-all pattern.
4059func AssertJSONMatchesPattern (
4160 t * testing.T ,
4261 expectedPattern json.RawMessage ,
@@ -63,89 +82,133 @@ func assertJSONMatchesPattern(
6382 require .NoError (t , err )
6483 }
6584
66- detectEscape := func (m map [string ]interface {}) (interface {}, bool ) {
67- if len (m ) != 1 {
68- return nil , false
85+ match (t , "#" , p , a )
86+ }
87+
88+ func match (t testingT , path string , p , a interface {}) {
89+ switch pp := p .(type ) {
90+ case string :
91+ if pp != "*" {
92+ assertJSONEquals (t , path , p , a )
6993 }
70- for k , v := range m {
71- if k == "\\ " {
72- return v , true
73- }
94+ case []interface {}:
95+ aa , ok := a .([]interface {})
96+ if ! ok {
97+ t .Errorf ("[%s]: expected an array, but got %s" , path , prettyJSON (t , a ))
98+ return
7499 }
75- return nil , false
100+ if len (aa ) != len (pp ) {
101+ t .Errorf ("[%s]: expected an array of length %d, but got %s" ,
102+ path , len (pp ), prettyJSON (t , a ))
103+ return
104+ }
105+ for i , pv := range pp {
106+ av := aa [i ]
107+ match (t , fmt .Sprintf ("%s[%d]" , path , i ), pv , av )
108+ }
109+ case map [string ]interface {}:
110+ matchObjectPattern (t , path , pp , a )
111+ default :
112+ assertJSONEquals (t , path , p , a )
76113 }
114+ }
77115
78- var match func (path string , p , a interface {})
79- match = func (path string , p , a interface {}) {
80- switch pp := p .(type ) {
81- case string :
82- if pp != "*" {
83- assertJSONEquals (t , path , p , a )
84- }
85- case []interface {}:
86- aa , ok := a .([]interface {})
87- if ! ok {
88- t .Errorf ("[%s]: expected an array, but got %s" , path , prettyJSON (t , a ))
89- return
90- }
91- if len (aa ) != len (pp ) {
92- t .Errorf ("[%s]: expected an array of length %d, but got %s" ,
93- path , len (pp ), prettyJSON (t , a ))
94- return
95- }
96- for i , pv := range pp {
97- av := aa [i ]
98- match (fmt .Sprintf ("%s[%d]" , path , i ), pv , av )
99- }
100- case map [string ]interface {}:
101- if esc , isEsc := detectEscape (pp ); isEsc {
102- assertJSONEquals (t , path , esc , a )
103- return
104- }
105-
106- aa , ok := a .(map [string ]interface {})
107- if ! ok {
108- t .Errorf ("[%s]: expected an object, but got %s" , path , prettyJSON (t , a ))
109- return
110- }
111-
112- seenKeys := map [string ]bool {}
113- allKeys := []string {}
114-
115- for k := range pp {
116- if ! seenKeys [k ] {
117- allKeys = append (allKeys , k )
118- }
119- seenKeys [k ] = true
120- }
121-
122- for k := range aa {
123- if ! seenKeys [k ] {
124- allKeys = append (allKeys , k )
125- }
126- seenKeys [k ] = true
127- }
128- sort .Strings (allKeys )
129-
130- for _ , k := range allKeys {
131- pv , gotPV := pp [k ]
132- av , gotAV := aa [k ]
133- subPath := fmt .Sprintf ("%s[%q]" , path , k )
134- switch {
135- case gotPV && gotAV :
136- match (subPath , pv , av )
137- case ! gotPV && gotAV :
138- t .Errorf ("[%s] unexpected value %s" , subPath , prettyJSON (t , av ))
139- case gotPV && ! gotAV :
140- t .Errorf ("[%s] missing a required value" , subPath )
141- }
142- }
143- default :
144- assertJSONEquals (t , path , p , a )
116+ type objectPattern struct {
117+ keyPatterns map [string ]any
118+ catchAllPattern any
119+ hasCatchAllPattern bool
120+ }
121+
122+ func (p * objectPattern ) sortedKeyUnion (value map [string ]any ) []string {
123+ var keys []string
124+ for k := range p .keyPatterns {
125+ keys = append (keys , k )
126+ }
127+ for k := range value {
128+ if _ , seen := p .keyPatterns [k ]; seen {
129+ continue
130+ }
131+ keys = append (keys , k )
132+ }
133+ sort .Strings (keys )
134+ return keys
135+ }
136+
137+ func compileObjectPattern (pattern map [string ]any ) (objectPattern , error ) {
138+ o := objectPattern {
139+ keyPatterns : map [string ]any {},
140+ }
141+
142+ var err error
143+ for k , v := range pattern {
144+ if k == "*" {
145+ o .hasCatchAllPattern = true
146+ o .catchAllPattern = v
147+ continue
148+ }
149+
150+ // Keys in object patterns may be escaped.
151+ cleanKey := strings .TrimPrefix (k , "\\ " )
152+
153+ if _ , conflict := o .keyPatterns [cleanKey ]; conflict {
154+ err = errors .Join (err , fmt .Errorf ("object key pattern %q specified more than once" , cleanKey ))
145155 }
156+
157+ o .keyPatterns [cleanKey ] = v
158+ }
159+
160+ if err != nil {
161+ return objectPattern {}, err
162+ }
163+
164+ return o , nil
165+ }
166+
167+ func matchObjectPattern (t testingT , path string , pattern map [string ]any , value any ) {
168+ if esc , isEsc := detectEscape (pattern ); isEsc {
169+ assertJSONEquals (t , path , esc , value )
170+ return
146171 }
147172
148- match ("#" , p , a )
173+ objPattern , err := compileObjectPattern (pattern )
174+ if err != nil {
175+ t .Errorf ("[%s]: %v" , err )
176+ return
177+ }
178+
179+ aa , ok := value .(map [string ]interface {})
180+ if ! ok {
181+ t .Errorf ("[%s]: expected an object, but got %s" , path , prettyJSON (t , value ))
182+ return
183+ }
184+
185+ for _ , k := range objPattern .sortedKeyUnion (aa ) {
186+ pv , gotPV := objPattern .keyPatterns [k ]
187+ av , gotAV := aa [k ]
188+ subPath := fmt .Sprintf ("%s[%q]" , path , k )
189+ switch {
190+ case gotPV && gotAV :
191+ match (t , subPath , pv , av )
192+ case ! gotPV && gotAV && ! objPattern .hasCatchAllPattern :
193+ t .Errorf ("[%s] unexpected value %s" , subPath , prettyJSON (t , av ))
194+ case ! gotPV && gotAV && objPattern .hasCatchAllPattern :
195+ match (t , subPath , objPattern .catchAllPattern , av )
196+ case gotPV && ! gotAV :
197+ t .Errorf ("[%s] missing a required value" , subPath )
198+ }
199+ }
200+ }
201+
202+ func detectEscape (m map [string ]interface {}) (interface {}, bool ) {
203+ if len (m ) != 1 {
204+ return nil , false
205+ }
206+ for k , v := range m {
207+ if k == "\\ " {
208+ return v , true
209+ }
210+ }
211+ return nil , false
149212}
150213
151214func assertJSONEquals (t testingT , path string , expected , actual interface {}) {
0 commit comments