Skip to content

fix: include client_assertion in token hook payload #3949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ replace github.com/ory/hydra-client-go/v2 => ./internal/httpclient

replace github.com/gobuffalo/pop/v6 => github.com/ory/pop/v6 v6.2.1-0.20241121111754-e5dfc0f3344b

replace github.com/ory/fosite => github.com/phooijenga/fosite v0.0.0-20250225211800-ea87e12044d7

require (
github.com/ThalesIgnite/crypto11 v1.2.5
github.com/bradleyjkemp/cupaloy/v2 v2.8.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,6 @@ github.com/ory/analytics-go/v5 v5.0.1 h1:LX8T5B9FN8KZXOtxgN+R3I4THRRVB6+28IKgKBp
github.com/ory/analytics-go/v5 v5.0.1/go.mod h1:lWCiCjAaJkKfgR/BN5DCLMol8BjKS1x+4jxBxff/FF0=
github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d h1:By96ZSVuH5LyjXLVVMfvJoLVGHaT96LdOnwgFSLVf0E=
github.com/ory/dockertest/v3 v3.10.1-0.20240704115616-d229e74b748d/go.mod h1:F2FIjwwAk6CsNAs//B8+aPFQF0t84pbM8oliyNXwQrk=
github.com/ory/fosite v0.49.0 h1:KNqO7RVt/1X8F08/UI0Y+GRvcpscCWgjqvpLBQPRovo=
github.com/ory/fosite v0.49.0/go.mod h1:FAn7IY+I6DjT1r29wMouPeRYq63DWUuBj++96uOS4mE=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe h1:rvu4obdvqR0fkSIJ8IfgzKOWwZ5kOT2UNfLq81Qk7rc=
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe/go.mod h1:z4n3u6as84LbV4YmgjHhnwtccQqzf4cZlSk9f1FhygI=
github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTsTS8=
Expand All @@ -404,6 +402,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 h1:JhzVVoYvbOACxoUmOs6V/G4D5nPVUW73rKvXxP4XUJc=
github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE=
github.com/phooijenga/fosite v0.0.0-20250225211800-ea87e12044d7 h1:OowGroy4LX9hrZMRGncDq7g3e/rzezXZZlMkdkhOkaM=
github.com/phooijenga/fosite v0.0.0-20250225211800-ea87e12044d7/go.mod h1:FAn7IY+I6DjT1r29wMouPeRYq63DWUuBj++96uOS4mE=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
Expand Down
192 changes: 191 additions & 1 deletion oauth2/oauth2_jwt_bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,9 @@ func TestJWTBearer(t *testing.T) {
audience := reg.Config().OAuth2TokenURL(ctx).String()
grantType := "urn:ietf:params:oauth:grant-type:jwt-bearer"

jti := uuid.NewString()
token, _, err := signer.Generate(ctx, jwt.MapClaims{
"jti": uuid.NewString(),
"jti": jti,
"iss": trustGrant.Issuer,
"sub": trustGrant.Subject,
"aud": audience,
Expand Down Expand Up @@ -339,6 +340,7 @@ func TestJWTBearer(t *testing.T) {
require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes)
require.ElementsMatch(t, hookReq.Request.GrantedAudience, expectedGrantedAudience)
require.Equal(t, expectedPayload, hookReq.Request.Payload)
require.Equal(t, jti, hookReq.Request.JWTClaims["jti"])

claims := map[string]interface{}{
"hooked": true,
Expand Down Expand Up @@ -561,3 +563,191 @@ func TestJWTBearer(t *testing.T) {
t.Run("strategy=jwt", run("jwt"))
})
}

func TestJWTClientAssertion(t *testing.T) {
ctx := context.Background()

reg := testhelpers.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
_, admin := testhelpers.NewOAuth2Server(ctx, t, reg)

set, kid := uuid.NewString(), uuid.NewString()
keys, err := jwk.GenerateJWK(ctx, jose.RS256, kid, "sig")
require.NoError(t, err)
signer := jwk.NewDefaultJWTSigner(reg.Config(), reg, set)
signer.GetPrivateKey = func(ctx context.Context) (interface{}, error) {
return keys.Keys[0], nil
}

client := &hc.Client{
GrantTypes: []string{"client_credentials"},
Scope: "offline_access",
TokenEndpointAuthMethod: "private_key_jwt",
JSONWebKeys: &x.JoseJSONWebKeySet{
JSONWebKeySet: &jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{keys.Keys[0].Public()},
},
},
}
require.NoError(t, reg.ClientManager().CreateClient(ctx, client))

var newConf = func(client *hc.Client) *clientcredentials.Config {
return &clientcredentials.Config{
AuthStyle: goauth2.AuthStyleInParams,
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
Scopes: strings.Split(client.Scope, " "),
EndpointParams: url.Values{
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
},
}
}
var getToken = func(t *testing.T, conf *clientcredentials.Config) (*goauth2.Token, error) {
return conf.Token(context.Background())
}

var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, strategy string, checkExtraClaims bool) {
introspection := testhelpers.IntrospectToken(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: cl.Secret}, token.AccessToken, admin)

check := func(res gjson.Result) {
assert.EqualValues(t, cl.GetID(), res.Get("client_id").String(), "%s", res.Raw)
assert.EqualValues(t, cl.GetID(), res.Get("sub").String(), "%s", res.Raw)
assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), res.Get("iss").String(), "%s", res.Raw)

assert.EqualValues(t, res.Get("nbf").Int(), res.Get("iat").Int(), "%s", res.Raw)
assert.True(t, res.Get("exp").Int() >= res.Get("iat").Int()+int64(reg.Config().GetAccessTokenLifespan(ctx).Seconds()), "%s", res.Raw)

if checkExtraClaims {
require.True(t, res.Get("ext.hooked").Bool())
}
}

check(introspection)
assert.True(t, introspection.Get("active").Bool())
assert.EqualValues(t, "access_token", introspection.Get("token_use").String())
assert.EqualValues(t, "Bearer", introspection.Get("token_type").String())
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)

if strategy != "jwt" {
return
}

body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
require.NoError(t, err)
jwtClaims := gjson.ParseBytes(body)
assert.NotEmpty(t, jwtClaims.Get("jti").String())
assert.NotEmpty(t, jwtClaims.Get("iss").String())
assert.NotEmpty(t, jwtClaims.Get("client_id").String())
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)

header, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[0])
require.NoError(t, err)
jwtHeader := gjson.ParseBytes(header)
assert.NotEmpty(t, jwtHeader.Get("kid").String())
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)

check(jwtClaims)
}

var generateAssertion = func() (string, jwt.MapClaims, error) {
claims := jwt.MapClaims{
"jti": uuid.NewString(),
"iss": client.GetID(),
"sub": client.GetID(),
"aud": reg.Config().OAuth2TokenURL(ctx).String(),
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-time.Minute).Unix(),
}
headers := &jwt.Headers{Extra: map[string]interface{}{"kid": kid}}
token, _, err := signer.Generate(ctx, claims, headers)
return token, claims, err
}

t.Run("case=unable to exchange invalid jwt", func(t *testing.T) {
conf := newConf(client)
conf.EndpointParams.Set("client_assertion", "not-a-jwt")
_, err := getToken(t, conf)
require.Error(t, err)
assert.Contains(t, err.Error(), "Unable to verify the integrity of the 'client_assertion' value.")
})

t.Run("case=should exchange for an access token", func(t *testing.T) {
run := func(strategy string) func(t *testing.T) {
return func(t *testing.T) {
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)

token, _, err := generateAssertion()
require.NoError(t, err)

conf := newConf(client)
conf.EndpointParams.Set("client_assertion", token)

result, err := getToken(t, conf)
require.NoError(t, err)

inspectToken(t, result, client, strategy, false)
}
}

t.Run("strategy=opaque", run("opaque"))
t.Run("strategy=jwt", run("jwt"))
})

t.Run("should call token hook if configured", func(t *testing.T) {
run := func(strategy string) func(t *testing.T) {
return func(t *testing.T) {
token, assertionClaims, err := generateAssertion()
require.NoError(t, err)

hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")

expectedGrantedScopes := []string{client.Scope}
expectedPayload := map[string][]string{
"grant_type": {"client_credentials"},
"scope": {"offline_access"},
}

var hookReq hydraoauth2.TokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
require.NotEmpty(t, hookReq.Session)
require.Equal(t, hookReq.Session.Extra, map[string]interface{}{})
require.NotEmpty(t, hookReq.Request)
require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes)
require.Equal(t, expectedPayload, hookReq.Request.Payload)
require.Equal(t, assertionClaims["jti"], hookReq.Request.JWTClaims["jti"])

claims := map[string]interface{}{
"hooked": true,
}

hookResp := hydraoauth2.TokenHookResponse{
Session: flow.AcceptOAuth2ConsentRequestSession{
AccessToken: claims,
IDToken: claims,
},
}

w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(&hookResp))
}))
defer hs.Close()

reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)

defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)

conf := newConf(client)
conf.EndpointParams.Set("client_assertion", token)

result, err := getToken(t, conf)
require.NoError(t, err)

inspectToken(t, result, client, strategy, true)
}
}

t.Run("strategy=opaque", run("opaque"))
t.Run("strategy=jwt", run("jwt"))
})
}
3 changes: 3 additions & 0 deletions oauth2/token_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Request struct {
GrantTypes []string `json:"grant_types"`
// Payload is the requests payload.
Payload map[string][]string `json:"payload"`
// JWTClaims contains the decoded JWT claims (RFC 7523).
JWTClaims map[string]interface{} `json:"jwt_claims"`
}

// TokenHookRequest is the request body sent to the token hook.
Expand Down Expand Up @@ -177,6 +179,7 @@ func TokenHook(reg interface {
GrantedAudience: requester.GetGrantedAudience(),
GrantTypes: requester.GetGrantTypes(),
Payload: requester.Sanitize([]string{"assertion"}).GetRequestForm(),
JWTClaims: requester.GetJWTClaims(),
}

reqBody := TokenHookRequest{
Expand Down
Loading