@@ -46,14 +46,14 @@ type RemoteKeySet struct {
46
46
inflight * inflight
47
47
48
48
// A set of cached keys.
49
- cachedKeys [ ]jose.JSONWebKey
49
+ cachedKeys map [ string ]jose.JSONWebKey
50
50
}
51
51
52
52
// inflight is used to wait on some in-flight request from multiple goroutines.
53
53
type inflight struct {
54
54
doneCh chan struct {}
55
55
56
- keys [ ]jose.JSONWebKey
56
+ keys map [ string ]jose.JSONWebKey
57
57
err error
58
58
}
59
59
@@ -70,14 +70,14 @@ func (i *inflight) wait() <-chan struct{} {
70
70
// done can only be called by a single goroutine. It records the result of the
71
71
// inflight request and signals other goroutines that the result is safe to
72
72
// inspect.
73
- func (i * inflight ) done (keys [ ]jose.JSONWebKey , err error ) {
73
+ func (i * inflight ) done (keys map [ string ]jose.JSONWebKey , err error ) {
74
74
i .keys = keys
75
75
i .err = err
76
76
close (i .doneCh )
77
77
}
78
78
79
79
// result cannot be called until the wait() channel has returned a value.
80
- func (i * inflight ) result () ([ ]jose.JSONWebKey , error ) {
80
+ func (i * inflight ) result () (map [ string ]jose.JSONWebKey , error ) {
81
81
return i .keys , i .err
82
82
}
83
83
@@ -102,43 +102,53 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
102
102
break
103
103
}
104
104
105
- keys := r .keysFromCache ()
106
- for _ , key := range keys {
107
- if keyID == "" || key .KeyID == keyID {
108
- if payload , err := jws .Verify (& key ); err == nil {
109
- return payload , nil
110
- }
111
- }
105
+ if payload , ok := r .verifyWithKey (jws , keyID ); ok {
106
+ return payload , nil
112
107
}
113
-
114
108
// If the kid doesn't match, check for new keys from the remote. This is the
115
109
// strategy recommended by the spec.
116
110
//
117
111
// https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
118
- keys , err := r .keysFromRemote (ctx )
112
+ _ , err := r .keysFromRemote (ctx )
119
113
if err != nil {
120
114
return nil , fmt .Errorf ("fetching keys %v" , err )
121
115
}
122
116
123
- for _ , key := range keys {
124
- if keyID == "" || key .KeyID == keyID {
117
+ if payload , ok := r .verifyWithKey (jws , keyID ); ok {
118
+ return payload , nil
119
+ }
120
+
121
+ return nil , errors .New ("failed to verify id token signature" )
122
+ }
123
+
124
+ // verifyWithKey attempts to verify the jws using the key with keyID from the cache
125
+ // if keyID is the empty string, it tries each key in the cache
126
+ func (r * RemoteKeySet ) verifyWithKey (jws * jose.JSONWebSignature , keyID string ) (payload []byte , ok bool ) {
127
+ if keyID == "" {
128
+ for _ , key := range r .keysFromCache () {
125
129
if payload , err := jws .Verify (& key ); err == nil {
126
- return payload , nil
130
+ return payload , true
131
+ }
132
+ }
133
+ } else {
134
+ if key , ok := r .keysFromCache ()[keyID ]; ok {
135
+ if payload , err := jws .Verify (& key ); err == nil {
136
+ return payload , true
127
137
}
128
138
}
129
139
}
130
- return nil , errors . New ( "failed to verify id token signature" )
140
+ return nil , false
131
141
}
132
142
133
- func (r * RemoteKeySet ) keysFromCache () (keys [ ]jose.JSONWebKey ) {
143
+ func (r * RemoteKeySet ) keysFromCache () (keys map [ string ]jose.JSONWebKey ) {
134
144
r .mu .Lock ()
135
145
defer r .mu .Unlock ()
136
146
return r .cachedKeys
137
147
}
138
148
139
149
// keysFromRemote syncs the key set from the remote set, records the values in the
140
150
// cache, and returns the key set.
141
- func (r * RemoteKeySet ) keysFromRemote (ctx context.Context ) ([ ]jose.JSONWebKey , error ) {
151
+ func (r * RemoteKeySet ) keysFromRemote (ctx context.Context ) (map [ string ]jose.JSONWebKey , error ) {
142
152
// Need to lock to inspect the inflight request field.
143
153
r .mu .Lock ()
144
154
// If there's not a current inflight request, create one.
@@ -178,7 +188,7 @@ func (r *RemoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, e
178
188
}
179
189
}
180
190
181
- func (r * RemoteKeySet ) updateKeys () ([ ]jose.JSONWebKey , error ) {
191
+ func (r * RemoteKeySet ) updateKeys () (map [ string ]jose.JSONWebKey , error ) {
182
192
req , err := http .NewRequest ("GET" , r .jwksURL , nil )
183
193
if err != nil {
184
194
return nil , fmt .Errorf ("oidc: can't create request: %v" , err )
@@ -204,5 +214,9 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
204
214
if err != nil {
205
215
return nil , fmt .Errorf ("oidc: failed to decode keys: %v %s" , err , body )
206
216
}
207
- return keySet .Keys , nil
217
+ keys := make (map [string ]jose.JSONWebKey )
218
+ for _ , key := range keySet .Keys {
219
+ keys [key .KeyID ] = key
220
+ }
221
+ return keys , nil
208
222
}
0 commit comments