Skip to content

Commit 946ba3e

Browse files
committed
jwks: refactor RemoteKeySet cache to a map
this refactors RemoteKeySet to cache keys using a map so that keys can be looked up by keyID in constant-time.
1 parent 08563f6 commit 946ba3e

File tree

1 file changed

+35
-21
lines changed

1 file changed

+35
-21
lines changed

oidc/jwks.go

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ type RemoteKeySet struct {
4646
inflight *inflight
4747

4848
// A set of cached keys.
49-
cachedKeys []jose.JSONWebKey
49+
cachedKeys map[string]jose.JSONWebKey
5050
}
5151

5252
// inflight is used to wait on some in-flight request from multiple goroutines.
5353
type inflight struct {
5454
doneCh chan struct{}
5555

56-
keys []jose.JSONWebKey
56+
keys map[string]jose.JSONWebKey
5757
err error
5858
}
5959

@@ -70,14 +70,14 @@ func (i *inflight) wait() <-chan struct{} {
7070
// done can only be called by a single goroutine. It records the result of the
7171
// inflight request and signals other goroutines that the result is safe to
7272
// inspect.
73-
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
73+
func (i *inflight) done(keys map[string]jose.JSONWebKey, err error) {
7474
i.keys = keys
7575
i.err = err
7676
close(i.doneCh)
7777
}
7878

7979
// result cannot be called until the wait() channel has returned a value.
80-
func (i *inflight) result() ([]jose.JSONWebKey, error) {
80+
func (i *inflight) result() (map[string]jose.JSONWebKey, error) {
8181
return i.keys, i.err
8282
}
8383

@@ -102,43 +102,53 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
102102
break
103103
}
104104

105-
keys := r.keysFromCache()
106-
for _, key := range keys {
107-
if keyID == "" || key.KeyID == keyID {
108-
if payload, err := jws.Verify(&key); err == nil {
109-
return payload, nil
110-
}
111-
}
105+
if payload, ok := r.verifyWithKey(jws, keyID); ok {
106+
return payload, nil
112107
}
113-
114108
// If the kid doesn't match, check for new keys from the remote. This is the
115109
// strategy recommended by the spec.
116110
//
117111
// https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
118-
keys, err := r.keysFromRemote(ctx)
112+
_, err := r.keysFromRemote(ctx)
119113
if err != nil {
120114
return nil, fmt.Errorf("fetching keys %v", err)
121115
}
122116

123-
for _, key := range keys {
124-
if keyID == "" || key.KeyID == keyID {
117+
if payload, ok := r.verifyWithKey(jws, keyID); ok {
118+
return payload, nil
119+
}
120+
121+
return nil, errors.New("failed to verify id token signature")
122+
}
123+
124+
// verifyWithKey attempts to verify the jws using the key with keyID from the cache
125+
// if keyID is the empty string, it tries each key in the cache
126+
func (r *RemoteKeySet) verifyWithKey(jws *jose.JSONWebSignature, keyID string) (payload []byte, ok bool) {
127+
if keyID == "" {
128+
for _, key := range r.keysFromCache() {
125129
if payload, err := jws.Verify(&key); err == nil {
126-
return payload, nil
130+
return payload, true
131+
}
132+
}
133+
} else {
134+
if key, ok := r.keysFromCache()[keyID]; ok {
135+
if payload, err := jws.Verify(&key); err == nil {
136+
return payload, true
127137
}
128138
}
129139
}
130-
return nil, errors.New("failed to verify id token signature")
140+
return nil, false
131141
}
132142

133-
func (r *RemoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
143+
func (r *RemoteKeySet) keysFromCache() (keys map[string]jose.JSONWebKey) {
134144
r.mu.Lock()
135145
defer r.mu.Unlock()
136146
return r.cachedKeys
137147
}
138148

139149
// keysFromRemote syncs the key set from the remote set, records the values in the
140150
// cache, and returns the key set.
141-
func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
151+
func (r *RemoteKeySet) keysFromRemote(ctx context.Context) (map[string]jose.JSONWebKey, error) {
142152
// Need to lock to inspect the inflight request field.
143153
r.mu.Lock()
144154
// If there's not a current inflight request, create one.
@@ -178,7 +188,7 @@ func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e
178188
}
179189
}
180190

181-
func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
191+
func (r *RemoteKeySet) updateKeys() (map[string]jose.JSONWebKey, error) {
182192
req, err := http.NewRequest("GET", r.jwksURL, nil)
183193
if err != nil {
184194
return nil, fmt.Errorf("oidc: can't create request: %v", err)
@@ -204,5 +214,9 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
204214
if err != nil {
205215
return nil, fmt.Errorf("oidc: failed to decode keys: %v %s", err, body)
206216
}
207-
return keySet.Keys, nil
217+
keys := make(map[string]jose.JSONWebKey)
218+
for _, key := range keySet.Keys {
219+
keys[key.KeyID] = key
220+
}
221+
return keys, nil
208222
}

0 commit comments

Comments
 (0)