Skip to content

Commit cca3e4e

Browse files
committed
Keep extra backend provided id and access token claims on refresh
When using a refresh token, any potentially encoded extra claims for id and access token must be retained. With this change, the claims embedded into the refresh token are applied on top of any other access or id token claims before creating the corresponding token. This avoids loosing any of the backend provided claims when refresh tokens are used.
1 parent 311421c commit cca3e4e

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

oidc/provider/handlers.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func (p *Provider) AuthorizeResponse(rw http.ResponseWriter, req *http.Request,
239239

240240
// Create access token when requested.
241241
if _, ok := ar.ResponseTypes[oidc.ResponseTypeToken]; ok {
242-
accessTokenString, err = p.makeAccessToken(ctx, ar.ClientID, auth, nil)
242+
accessTokenString, err = p.makeAccessToken(ctx, ar.ClientID, auth, nil, nil)
243243
if err != nil {
244244
goto done
245245
}
@@ -248,7 +248,7 @@ func (p *Provider) AuthorizeResponse(rw http.ResponseWriter, req *http.Request,
248248
// Create ID token when requested and granted.
249249
if authorizedScopes[oidc.ScopeOpenID] {
250250
if _, ok := ar.ResponseTypes[oidc.ResponseTypeIDToken]; ok {
251-
idTokenString, err = p.makeIDToken(ctx, ar, auth, session, accessTokenString, codeString, nil)
251+
idTokenString, err = p.makeIDToken(ctx, ar, auth, session, accessTokenString, codeString, nil, nil)
252252
if err != nil {
253253
goto done
254254
}
@@ -330,6 +330,7 @@ func (p *Provider) TokenHandler(rw http.ResponseWriter, req *http.Request) {
330330
var accessTokenString string
331331
var idTokenString string
332332
var refreshTokenString string
333+
var refreshTokenClaims *konnect.RefreshTokenClaims
333334
var approvedScopes map[string]bool
334335
var authorizedScopes map[string]bool
335336
var clientDetails *clients.Details
@@ -498,13 +499,16 @@ func (p *Provider) TokenHandler(rw http.ResponseWriter, req *http.Request) {
498499
ClientID: claims.Audience,
499500
}
500501

502+
// Remember refresh token claims, for use in access and id token generators later on.
503+
refreshTokenClaims = claims
504+
501505
default:
502506
err = konnectoidc.NewOAuth2Error(oidc.ErrorCodeOAuth2UnsupportedGrantType, "grant_type value not implemented")
503507
goto done
504508
}
505509

506510
// Create access token.
507-
accessTokenString, err = p.makeAccessToken(ctx, ar.ClientID, auth, signinMethod)
511+
accessTokenString, err = p.makeAccessToken(ctx, ar.ClientID, auth, signinMethod, refreshTokenClaims)
508512
if err != nil {
509513
goto done
510514
}
@@ -513,7 +517,7 @@ func (p *Provider) TokenHandler(rw http.ResponseWriter, req *http.Request) {
513517
case oidc.GrantTypeAuthorizationCode, oidc.GrantTypeRefreshToken:
514518
// Create ID token when not previously requested amd openid scope is authorized.
515519
if !ar.ResponseTypes[oidc.ResponseTypeIDToken] && authorizedScopes[oidc.ScopeOpenID] {
516-
idTokenString, err = p.makeIDToken(ctx, ar, auth, session, accessTokenString, "", signinMethod)
520+
idTokenString, err = p.makeIDToken(ctx, ar, auth, session, accessTokenString, "", signinMethod, refreshTokenClaims)
517521
if err != nil {
518522
goto done
519523
}

oidc/provider/tokens.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ import (
3535

3636
// MakeAccessToken implements the oidc.AccessTokenProvider interface.
3737
func (p *Provider) MakeAccessToken(ctx context.Context, audience string, auth identity.AuthRecord) (string, error) {
38-
return p.makeAccessToken(ctx, audience, auth, nil)
38+
return p.makeAccessToken(ctx, audience, auth, nil, nil)
3939
}
4040

41-
func (p *Provider) makeAccessToken(ctx context.Context, audience string, auth identity.AuthRecord, signingMethod jwt.SigningMethod) (string, error) {
41+
func (p *Provider) makeAccessToken(ctx context.Context, audience string, auth identity.AuthRecord, signingMethod jwt.SigningMethod, refreshTokenClaims *konnect.RefreshTokenClaims) (string, error) {
4242
sk, ok := p.getSigningKey(signingMethod)
4343
if !ok {
4444
return "", fmt.Errorf("no signing key")
@@ -67,6 +67,17 @@ func (p *Provider) makeAccessToken(ctx context.Context, audience string, auth id
6767
accessTokenClaims.IdentityClaims = userWithClaims.Claims()
6868
}
6969
accessTokenClaims.IdentityProvider = auth.Manager().Name()
70+
if accessTokenClaims.IdentityClaims != nil && refreshTokenClaims != nil && refreshTokenClaims.IdentityClaims != nil {
71+
if refreshTokenClaims.IdentityProvider != accessTokenClaims.IdentityProvider {
72+
return "", fmt.Errorf("refresh token claims provider mismatch")
73+
}
74+
for k, v := range refreshTokenClaims.IdentityClaims {
75+
// Force to use refresh token identity claim values. This also locks all
76+
// the extra claims for id and access tokens to the ones provided from
77+
// the refresh token claims (which currently includes the session id).
78+
accessTokenClaims.IdentityClaims[k] = v
79+
}
80+
}
7081
}
7182

7283
// Support additional custom user specific claims.
@@ -113,7 +124,7 @@ func (p *Provider) makeAccessToken(ctx context.Context, audience string, auth id
113124
return accessToken.SignedString(sk.PrivateKey)
114125
}
115126

116-
func (p *Provider) makeIDToken(ctx context.Context, ar *payload.AuthenticationRequest, auth identity.AuthRecord, session *payload.Session, accessTokenString string, codeString string, signingMethod jwt.SigningMethod) (string, error) {
127+
func (p *Provider) makeIDToken(ctx context.Context, ar *payload.AuthenticationRequest, auth identity.AuthRecord, session *payload.Session, accessTokenString string, codeString string, signingMethod jwt.SigningMethod, refreshTokenClaims *konnect.RefreshTokenClaims) (string, error) {
117128
sk, ok := p.getSigningKey(signingMethod)
118129
if !ok {
119130
return "", fmt.Errorf("no signing key")
@@ -160,6 +171,18 @@ func (p *Provider) makeIDToken(ctx context.Context, ar *payload.AuthenticationRe
160171
if userWithClaims, ok := user.(identity.UserWithClaims); ok {
161172
accessTokenClaims.IdentityClaims = userWithClaims.Claims()
162173
}
174+
accessTokenClaims.IdentityProvider = auth.Manager().Name()
175+
if accessTokenClaims.IdentityClaims != nil && refreshTokenClaims != nil && refreshTokenClaims.IdentityClaims != nil {
176+
if refreshTokenClaims.IdentityProvider != accessTokenClaims.IdentityProvider {
177+
return "", fmt.Errorf("refresh token claims provider mismatch")
178+
}
179+
for k, v := range refreshTokenClaims.IdentityClaims {
180+
// Force to use refresh token identity claim values. This also locks all
181+
// the extra claims for id and access tokens to the ones provided from
182+
// the refresh token claims (which currently includes the session id).
183+
accessTokenClaims.IdentityClaims[k] = v
184+
}
185+
}
163186

164187
if withIDTokenClaimsRequest {
165188
// Apply additional information from ID token claims request.

0 commit comments

Comments
 (0)