From 81c5802db3af9ed4bb32a5151820fbdd10d5af6b Mon Sep 17 00:00:00 2001 From: Paul Hooijenga Date: Thu, 20 Feb 2025 10:07:47 +0100 Subject: [PATCH] fix: include JWT claims in token hook payload --- go.mod | 2 + go.sum | 4 +- oauth2/oauth2_jwt_bearer_test.go | 192 ++++++++++++++++++++++++++++++- oauth2/token_hook.go | 3 + 4 files changed, 198 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2a9413c0ea5..9ae6baa0275 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,8 @@ replace github.com/ory/hydra-client-go/v2 => ./internal/httpclient replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.1-0.20241121111754-e5dfc0f3344b +replace github.com/ory/fosite => github.com/phooijenga/fosite v0.0.0-20250225211800-ea87e12044d7 + require ( github.com/ThalesIgnite/crypto11 v1.2.5 github.com/bradleyjkemp/cupaloy/v2 v2.8.0 diff --git a/go.sum b/go.sum index f621e7822b7..d1a8701980c 100644 --- a/go.sum +++ b/go.sum @@ -378,8 +378,6 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E= github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk= -github.com/ory/fosite v0.49.0 h1:KNqO7RVt/1X8F08/UI0Y+GRvcpscCWgjqvpLBQPRovo= -github.com/ory/fosite v0.49.0/go.mod h1:FAn7IY+I6DjT1r29wMouPeRYq63DWUuBj++96uOS4mE= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc= github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI= github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8= @@ -404,6 +402,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 h1:JhzVVoYvbOACxoUmOs6V/G4D5nPVUW73rKvXxP4XUJc= github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= +github.com/phooijenga/fosite v0.0.0-20250225211800-ea87e12044d7 h1:OowGroy4LX9hrZMRGncDq7g3e/rzezXZZlMkdkhOkaM= +github.com/phooijenga/fosite v0.0.0-20250225211800-ea87e12044d7/go.mod h1:FAn7IY+I6DjT1r29wMouPeRYq63DWUuBj++96uOS4mE= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/oauth2/oauth2_jwt_bearer_test.go b/oauth2/oauth2_jwt_bearer_test.go index 0b1a862ba05..ec2d6fd848a 100644 --- a/oauth2/oauth2_jwt_bearer_test.go +++ b/oauth2/oauth2_jwt_bearer_test.go @@ -310,8 +310,9 @@ func TestJWTBearer(t *testing.T) { audience := reg.Config().OAuth2TokenURL(ctx).String() grantType := "urn:ietf:params:oauth:grant-type:jwt-bearer" + jti := uuid.NewString() token, _, err := signer.Generate(ctx, jwt.MapClaims{ - "jti": uuid.NewString(), + "jti": jti, "iss": trustGrant.Issuer, "sub": trustGrant.Subject, "aud": audience, @@ -339,6 +340,7 @@ func TestJWTBearer(t *testing.T) { require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) require.ElementsMatch(t, hookReq.Request.GrantedAudience, expectedGrantedAudience) require.Equal(t, expectedPayload, hookReq.Request.Payload) + require.Equal(t, jti, hookReq.Request.JWTClaims["jti"]) claims := map[string]interface{}{ "hooked": true, @@ -561,3 +563,191 @@ func TestJWTBearer(t *testing.T) { t.Run("strategy=jwt", run("jwt")) }) } + +func TestJWTClientAssertion(t *testing.T) { + ctx := context.Background() + + reg := testhelpers.NewMockedRegistry(t, &contextx.Default{}) + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque") + _, admin := testhelpers.NewOAuth2Server(ctx, t, reg) + + set, kid := uuid.NewString(), uuid.NewString() + keys, err := jwk.GenerateJWK(ctx, jose.RS256, kid, "sig") + require.NoError(t, err) + signer := jwk.NewDefaultJWTSigner(reg.Config(), reg, set) + signer.GetPrivateKey = func(ctx context.Context) (interface{}, error) { + return keys.Keys[0], nil + } + + client := &hc.Client{ + GrantTypes: []string{"client_credentials"}, + Scope: "offline_access", + TokenEndpointAuthMethod: "private_key_jwt", + JSONWebKeys: &x.JoseJSONWebKeySet{ + JSONWebKeySet: &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{keys.Keys[0].Public()}, + }, + }, + } + require.NoError(t, reg.ClientManager().CreateClient(ctx, client)) + + var newConf = func(client *hc.Client) *clientcredentials.Config { + return &clientcredentials.Config{ + AuthStyle: goauth2.AuthStyleInParams, + TokenURL: reg.Config().OAuth2TokenURL(ctx).String(), + Scopes: strings.Split(client.Scope, " "), + EndpointParams: url.Values{ + "client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"}, + }, + } + } + var getToken = func(t *testing.T, conf *clientcredentials.Config) (*goauth2.Token, error) { + return conf.Token(context.Background()) + } + + var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, strategy string, checkExtraClaims bool) { + introspection := testhelpers.IntrospectToken(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: cl.Secret}, token.AccessToken, admin) + + check := func(res gjson.Result) { + assert.EqualValues(t, cl.GetID(), res.Get("client_id").String(), "%s", res.Raw) + assert.EqualValues(t, cl.GetID(), res.Get("sub").String(), "%s", res.Raw) + assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), res.Get("iss").String(), "%s", res.Raw) + + assert.EqualValues(t, res.Get("nbf").Int(), res.Get("iat").Int(), "%s", res.Raw) + assert.True(t, res.Get("exp").Int() >= res.Get("iat").Int()+int64(reg.Config().GetAccessTokenLifespan(ctx).Seconds()), "%s", res.Raw) + + if checkExtraClaims { + require.True(t, res.Get("ext.hooked").Bool()) + } + } + + check(introspection) + assert.True(t, introspection.Get("active").Bool()) + assert.EqualValues(t, "access_token", introspection.Get("token_use").String()) + assert.EqualValues(t, "Bearer", introspection.Get("token_type").String()) + assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw) + + if strategy != "jwt" { + return + } + + body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1]) + require.NoError(t, err) + jwtClaims := gjson.ParseBytes(body) + assert.NotEmpty(t, jwtClaims.Get("jti").String()) + assert.NotEmpty(t, jwtClaims.Get("iss").String()) + assert.NotEmpty(t, jwtClaims.Get("client_id").String()) + assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw) + + header, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[0]) + require.NoError(t, err) + jwtHeader := gjson.ParseBytes(header) + assert.NotEmpty(t, jwtHeader.Get("kid").String()) + assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw) + + check(jwtClaims) + } + + var generateAssertion = func() (string, jwt.MapClaims, error) { + claims := jwt.MapClaims{ + "jti": uuid.NewString(), + "iss": client.GetID(), + "sub": client.GetID(), + "aud": reg.Config().OAuth2TokenURL(ctx).String(), + "exp": time.Now().Add(time.Hour).Unix(), + "iat": time.Now().Add(-time.Minute).Unix(), + } + headers := &jwt.Headers{Extra: map[string]interface{}{"kid": kid}} + token, _, err := signer.Generate(ctx, claims, headers) + return token, claims, err + } + + t.Run("case=unable to exchange invalid jwt", func(t *testing.T) { + conf := newConf(client) + conf.EndpointParams.Set("client_assertion", "not-a-jwt") + _, err := getToken(t, conf) + require.Error(t, err) + assert.Contains(t, err.Error(), "Unable to verify the integrity of the 'client_assertion' value.") + }) + + t.Run("case=should exchange for an access token", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + + token, _, err := generateAssertion() + require.NoError(t, err) + + conf := newConf(client) + conf.EndpointParams.Set("client_assertion", token) + + result, err := getToken(t, conf) + require.NoError(t, err) + + inspectToken(t, result, client, strategy, false) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) + + t.Run("should call token hook if configured", func(t *testing.T) { + run := func(strategy string) func(t *testing.T) { + return func(t *testing.T) { + token, assertionClaims, err := generateAssertion() + require.NoError(t, err) + + hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8") + + expectedGrantedScopes := []string{client.Scope} + expectedPayload := map[string][]string{ + "grant_type": {"client_credentials"}, + "scope": {"offline_access"}, + } + + var hookReq hydraoauth2.TokenHookRequest + require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq)) + require.NotEmpty(t, hookReq.Session) + require.Equal(t, hookReq.Session.Extra, map[string]interface{}{}) + require.NotEmpty(t, hookReq.Request) + require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes) + require.Equal(t, expectedPayload, hookReq.Request.Payload) + require.Equal(t, assertionClaims["jti"], hookReq.Request.JWTClaims["jti"]) + + claims := map[string]interface{}{ + "hooked": true, + } + + hookResp := hydraoauth2.TokenHookResponse{ + Session: flow.AcceptOAuth2ConsentRequestSession{ + AccessToken: claims, + IDToken: claims, + }, + } + + w.WriteHeader(http.StatusOK) + require.NoError(t, json.NewEncoder(w).Encode(&hookResp)) + })) + defer hs.Close() + + reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy) + reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL) + + defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil) + + conf := newConf(client) + conf.EndpointParams.Set("client_assertion", token) + + result, err := getToken(t, conf) + require.NoError(t, err) + + inspectToken(t, result, client, strategy, true) + } + } + + t.Run("strategy=opaque", run("opaque")) + t.Run("strategy=jwt", run("jwt")) + }) +} diff --git a/oauth2/token_hook.go b/oauth2/token_hook.go index 04122dd2ef9..4632da8a1f6 100644 --- a/oauth2/token_hook.go +++ b/oauth2/token_hook.go @@ -41,6 +41,8 @@ type Request struct { GrantTypes []string `json:"grant_types"` // Payload is the requests payload. Payload map[string][]string `json:"payload"` + // JWTClaims contains the decoded JWT claims (RFC 7523). + JWTClaims map[string]interface{} `json:"jwt_claims"` } // TokenHookRequest is the request body sent to the token hook. @@ -177,6 +179,7 @@ func TokenHook(reg interface { GrantedAudience: requester.GetGrantedAudience(), GrantTypes: requester.GetGrantTypes(), Payload: requester.Sanitize([]string{"assertion"}).GetRequestForm(), + JWTClaims: requester.GetJWTClaims(), } reqBody := TokenHookRequest{