Skip to content

Commit e96e3eb

Browse files
committed
[TT-16013] fix bug gateway panics when validating jwt claims
1 parent 26c5240 commit e96e3eb

13 files changed

Lines changed: 532 additions & 333 deletions

apidef/oas/authentication.go

Lines changed: 177 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package oas
33
import (
44
"encoding/json"
55
"fmt"
6+
"iter"
67
"reflect"
78
"sort"
9+
"sync"
810
"time"
911

1012
"github.com/getkin/kin-openapi/openapi3"
@@ -76,7 +78,7 @@ type Authentication struct {
7678
Custom *CustomPluginAuthentication `bson:"custom,omitempty" json:"custom,omitempty"`
7779

7880
// SecuritySchemes contains security schemes definitions.
79-
SecuritySchemes SecuritySchemes `bson:"securitySchemes,omitempty" json:"securitySchemes,omitempty"`
81+
SecuritySchemes *SecuritySchemes `bson:"securitySchemes,omitempty" json:"securitySchemes,omitempty"`
8082

8183
// CustomKeyLifetime contains configuration for the maximum retention period for access tokens.
8284
CustomKeyLifetime *CustomKeyLifetime `bson:"customKeyLifetime,omitempty" json:"customKeyLifetime,omitempty"`
@@ -229,73 +231,209 @@ func (a *Authentication) ExtractTo(api *apidef.APIDefinition) {
229231
a.CustomKeyLifetime.ExtractTo(api)
230232
}
231233

232-
// SecuritySchemes holds security scheme values, filled with Import().
233-
type SecuritySchemes map[string]interface{}
234+
// SecuritySchemes never use the zero value of the type, always use the constructor.
235+
type SecuritySchemes struct {
236+
container map[string]interface{}
237+
mutex sync.RWMutex
238+
}
239+
240+
func NewSecuritySchemes() *SecuritySchemes {
241+
return &SecuritySchemes{container: make(map[string]interface{})}
242+
}
243+
244+
func (ss *SecuritySchemes) Set(key string, value interface{}) {
245+
if ss == nil {
246+
return
247+
}
248+
ss.mutex.Lock()
249+
defer ss.mutex.Unlock()
250+
251+
if ss.container == nil {
252+
ss.container = make(map[string]interface{})
253+
}
254+
ss.container[key] = value
255+
}
256+
257+
func (ss *SecuritySchemes) Get(key string) (interface{}, bool) {
258+
if ss == nil {
259+
return nil, false
260+
}
261+
ss.mutex.RLock()
262+
defer ss.mutex.RUnlock()
263+
264+
if ss.container == nil {
265+
return nil, false
266+
}
267+
value, ok := ss.container[key]
268+
return value, ok
269+
}
270+
271+
func (ss *SecuritySchemes) Delete(key string) {
272+
if ss == nil {
273+
return
274+
}
275+
ss.mutex.Lock()
276+
defer ss.mutex.Unlock()
277+
278+
if ss.container == nil {
279+
return
280+
}
281+
delete(ss.container, key)
282+
}
283+
284+
func (ss *SecuritySchemes) Len() int {
285+
if ss == nil {
286+
return 0
287+
}
288+
ss.mutex.RLock()
289+
defer ss.mutex.RUnlock()
290+
291+
return len(ss.container)
292+
}
293+
294+
// Iter returns a snapshot iterator: it copies entries under RLock,
295+
// then releases the lock before calling user code.
296+
func (ss *SecuritySchemes) Iter() iter.Seq2[string, interface{}] {
297+
return func(yield func(string, interface{}) bool) {
298+
if ss == nil {
299+
return
300+
}
301+
302+
ss.mutex.RLock()
303+
if ss.container == nil {
304+
ss.mutex.RUnlock()
305+
return
306+
}
307+
308+
snap := make([]struct {
309+
k string
310+
v interface{}
311+
}, 0, len(ss.container))
312+
for k, v := range ss.container {
313+
snap = append(snap, struct {
314+
k string
315+
v interface{}
316+
}{k, v})
317+
}
318+
ss.mutex.RUnlock()
319+
320+
for _, e := range snap {
321+
if !yield(e.k, e.v) {
322+
return
323+
}
324+
}
325+
}
326+
}
327+
328+
// MarshalJSON implements json.Marshaler.
329+
// It snapshots the map under RLock, unlocks, then marshals.
330+
func (ss *SecuritySchemes) MarshalJSON() ([]byte, error) {
331+
if ss == nil {
332+
return []byte(`{}`), nil
333+
}
334+
335+
ss.mutex.RLock()
336+
defer ss.mutex.RUnlock()
337+
338+
if ss.container == nil {
339+
return []byte(`{}`), nil
340+
}
341+
342+
return json.Marshal(ss.container)
343+
}
344+
345+
// UnmarshalJSON implements json.Unmarshaler.
346+
// It builds a temporary map, then replaces the internal map under Lock.
347+
func (ss *SecuritySchemes) UnmarshalJSON(b []byte) error {
348+
if string(b) == "null" {
349+
if ss != nil {
350+
ss.mutex.Lock()
351+
ss.container = make(map[string]interface{})
352+
ss.mutex.Unlock()
353+
}
354+
return nil
355+
}
356+
357+
tmp := make(map[string]interface{})
358+
if err := json.Unmarshal(b, &tmp); err != nil {
359+
return err
360+
}
361+
362+
if ss != nil {
363+
ss.mutex.Lock()
364+
ss.container = tmp
365+
ss.mutex.Unlock()
366+
}
367+
return nil
368+
}
234369

235370
// SecurityScheme defines an Importer interface for security schemes.
236371
type SecurityScheme interface {
237372
Import(nativeSS *openapi3.SecurityScheme, enable bool)
238373
}
239374

240-
// Import takes the openapi3.SecurityScheme as argument and applies it to the receiver. The
241-
// SecuritySchemes receiver is a map, so modification of the receiver is enabled, regardless
242-
// of the fact that the receiver isn't a pointer type. The map is a pointer type itself.
243-
func (ss SecuritySchemes) Import(name string, nativeSS *openapi3.SecurityScheme, enable bool) error {
375+
// Import takes the openapi3.SecurityScheme as argument and applies it to the receiver.
376+
// The SecuritySchemes type uses internal synchronization to safely modify its contents.
377+
func (ss *SecuritySchemes) Import(name string, nativeSS *openapi3.SecurityScheme, enable bool) error {
378+
if ss == nil {
379+
return fmt.Errorf("SecuritySchemes is nil")
380+
}
381+
244382
switch {
245383
case nativeSS.Type == typeAPIKey:
246384
token := &Token{}
247-
if ss[name] == nil {
248-
ss[name] = token
385+
scheme, exists := ss.Get(name)
386+
if !exists {
387+
ss.Set(name, token)
388+
}
389+
if tokenVal, ok := scheme.(*Token); ok {
390+
token = tokenVal
249391
} else {
250-
if tokenVal, ok := ss[name].(*Token); ok {
251-
token = tokenVal
252-
} else {
253-
toStructIfMap(ss[name], token)
254-
}
392+
toStructIfMap(scheme, token)
255393
}
256394

257395
token.Enabled = &enable
258396
case nativeSS.Type == typeHTTP && nativeSS.Scheme == schemeBearer && nativeSS.BearerFormat == bearerFormatJWT:
259397
jwt := &JWT{}
260-
if ss[name] == nil {
261-
ss[name] = jwt
398+
scheme, exists := ss.Get(name)
399+
if !exists {
400+
ss.Set(name, jwt)
401+
}
402+
if jwtVal, ok := scheme.(*JWT); ok {
403+
jwt = jwtVal
262404
} else {
263-
if jwtVal, ok := ss[name].(*JWT); ok {
264-
jwt = jwtVal
265-
} else {
266-
toStructIfMap(ss[name], jwt)
267-
}
405+
toStructIfMap(scheme, jwt)
268406
}
269407

270408
jwt.Import(enable)
271409
case nativeSS.Type == typeHTTP && nativeSS.Scheme == schemeBasic:
272410
basic := &Basic{}
273-
if ss[name] == nil {
274-
ss[name] = basic
411+
scheme, exists := ss.Get(name)
412+
if !exists {
413+
ss.Set(name, basic)
414+
}
415+
if basicVal, ok := scheme.(*Basic); ok {
416+
basic = basicVal
275417
} else {
276-
if basicVal, ok := ss[name].(*Basic); ok {
277-
basic = basicVal
278-
} else {
279-
toStructIfMap(ss[name], basic)
280-
}
418+
toStructIfMap(scheme, basic)
281419
}
282420

283421
basic.Import(enable)
284422
case nativeSS.Type == typeOAuth2:
285423
oauth := &OAuth{}
286-
if ss[name] == nil {
287-
ss[name] = oauth
424+
scheme, exists := ss.Get(name)
425+
if !exists {
426+
ss.Set(name, oauth)
427+
}
428+
if oauthVal, ok := scheme.(*OAuth); ok {
429+
oauth = oauthVal
288430
} else {
289-
if oauthVal, ok := ss[name].(*OAuth); ok {
290-
oauth = oauthVal
291-
} else {
292-
toStructIfMap(ss[name], oauth)
293-
}
431+
toStructIfMap(scheme, oauth)
294432
}
295433

296434
oauth.Import(enable)
297435
default:
298-
return fmt.Errorf(unsupportedSecuritySchemeFmt, name)
436+
return fmt.Errorf("unsupported securityScheme type: %s", nativeSS.Type)
299437
}
300438

301439
return nil
@@ -317,15 +455,15 @@ func baseIdentityProviderPrecedence(authType apidef.AuthTypeEnum) int {
317455
}
318456

319457
// GetBaseIdentityProvider returns the identity provider by precedence from SecuritySchemes.
320-
func (ss SecuritySchemes) GetBaseIdentityProvider() (res apidef.AuthTypeEnum) {
321-
if len(ss) < 2 {
458+
func (ss *SecuritySchemes) GetBaseIdentityProvider() (res apidef.AuthTypeEnum) {
459+
if ss == nil || ss.Len() < 2 {
322460
return
323461
}
324462

325463
resBaseIdentityProvider := baseIdentityProviderPrecedence(apidef.AuthTypeNone)
326464
res = apidef.OAuthKey
327465

328-
for _, scheme := range ss {
466+
for _, scheme := range ss.Iter() {
329467
if _, ok := scheme.(*Token); ok {
330468
return apidef.AuthToken
331469
}

apidef/oas/default.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func (s *OAS) importAuthentication(enable bool) error {
190190

191191
tykSecuritySchemes := authentication.SecuritySchemes
192192
if tykSecuritySchemes == nil {
193-
tykSecuritySchemes = make(SecuritySchemes)
193+
tykSecuritySchemes = NewSecuritySchemes()
194194
authentication.SecuritySchemes = tykSecuritySchemes
195195
}
196196

0 commit comments

Comments
 (0)