Skip to content

Commit c1d0eb1

Browse files
committed
fix: include client_assertion in token hook payload
1 parent 5d2ca41 commit c1d0eb1

File tree

2 files changed

+188
-1
lines changed

2 files changed

+188
-1
lines changed

oauth2/oauth2_jwt_bearer_test.go

+187
Original file line numberDiff line numberDiff line change
@@ -561,3 +561,190 @@ func TestJWTBearer(t *testing.T) {
561561
t.Run("strategy=jwt", run("jwt"))
562562
})
563563
}
564+
565+
func TestJWTClientAssertion(t *testing.T) {
566+
ctx := context.Background()
567+
568+
reg := testhelpers.NewMockedRegistry(t, &contextx.Default{})
569+
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
570+
_, admin := testhelpers.NewOAuth2Server(ctx, t, reg)
571+
572+
set, kid := uuid.NewString(), uuid.NewString()
573+
keys, err := jwk.GenerateJWK(ctx, jose.RS256, kid, "sig")
574+
require.NoError(t, err)
575+
signer := jwk.NewDefaultJWTSigner(reg.Config(), reg, set)
576+
signer.GetPrivateKey = func(ctx context.Context) (interface{}, error) {
577+
return keys.Keys[0], nil
578+
}
579+
580+
client := &hc.Client{
581+
GrantTypes: []string{"client_credentials"},
582+
Scope: "offline_access",
583+
TokenEndpointAuthMethod: "private_key_jwt",
584+
JSONWebKeys: &x.JoseJSONWebKeySet{
585+
JSONWebKeySet: &jose.JSONWebKeySet{
586+
Keys: []jose.JSONWebKey{keys.Keys[0].Public()},
587+
},
588+
},
589+
}
590+
require.NoError(t, reg.ClientManager().CreateClient(ctx, client))
591+
592+
var newConf = func(client *hc.Client) *clientcredentials.Config {
593+
return &clientcredentials.Config{
594+
AuthStyle: goauth2.AuthStyleInParams,
595+
TokenURL: reg.Config().OAuth2TokenURL(ctx).String(),
596+
Scopes: strings.Split(client.Scope, " "),
597+
EndpointParams: url.Values{
598+
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
599+
},
600+
}
601+
}
602+
var getToken = func(t *testing.T, conf *clientcredentials.Config) (*goauth2.Token, error) {
603+
return conf.Token(context.Background())
604+
}
605+
606+
var inspectToken = func(t *testing.T, token *goauth2.Token, cl *hc.Client, strategy string, checkExtraClaims bool) {
607+
introspection := testhelpers.IntrospectToken(t, &goauth2.Config{ClientID: cl.GetID(), ClientSecret: cl.Secret}, token.AccessToken, admin)
608+
609+
check := func(res gjson.Result) {
610+
assert.EqualValues(t, cl.GetID(), res.Get("client_id").String(), "%s", res.Raw)
611+
assert.EqualValues(t, cl.GetID(), res.Get("sub").String(), "%s", res.Raw)
612+
assert.EqualValues(t, reg.Config().IssuerURL(ctx).String(), res.Get("iss").String(), "%s", res.Raw)
613+
614+
assert.EqualValues(t, res.Get("nbf").Int(), res.Get("iat").Int(), "%s", res.Raw)
615+
assert.True(t, res.Get("exp").Int() >= res.Get("iat").Int()+int64(reg.Config().GetAccessTokenLifespan(ctx).Seconds()), "%s", res.Raw)
616+
617+
if checkExtraClaims {
618+
require.True(t, res.Get("ext.hooked").Bool())
619+
}
620+
}
621+
622+
check(introspection)
623+
assert.True(t, introspection.Get("active").Bool())
624+
assert.EqualValues(t, "access_token", introspection.Get("token_use").String())
625+
assert.EqualValues(t, "Bearer", introspection.Get("token_type").String())
626+
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)
627+
628+
if strategy != "jwt" {
629+
return
630+
}
631+
632+
body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
633+
require.NoError(t, err)
634+
jwtClaims := gjson.ParseBytes(body)
635+
assert.NotEmpty(t, jwtClaims.Get("jti").String())
636+
assert.NotEmpty(t, jwtClaims.Get("iss").String())
637+
assert.NotEmpty(t, jwtClaims.Get("client_id").String())
638+
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)
639+
640+
header, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[0])
641+
require.NoError(t, err)
642+
jwtHeader := gjson.ParseBytes(header)
643+
assert.NotEmpty(t, jwtHeader.Get("kid").String())
644+
assert.EqualValues(t, "offline_access", introspection.Get("scope").String(), "%s", introspection.Raw)
645+
646+
check(jwtClaims)
647+
}
648+
649+
var generateAssertion = func() (string, error) {
650+
token, _, err := signer.Generate(ctx, jwt.MapClaims{
651+
"jti": uuid.NewString(),
652+
"iss": client.GetID(),
653+
"sub": client.GetID(),
654+
"aud": reg.Config().OAuth2TokenURL(ctx).String(),
655+
"exp": time.Now().Add(time.Hour).Unix(),
656+
"iat": time.Now().Add(-time.Minute).Unix(),
657+
}, &jwt.Headers{Extra: map[string]interface{}{"kid": kid}})
658+
return token, err
659+
}
660+
661+
t.Run("case=unable to exchange invalid jwt", func(t *testing.T) {
662+
conf := newConf(client)
663+
conf.EndpointParams.Set("client_assertion", "not-a-jwt")
664+
_, err := getToken(t, conf)
665+
require.Error(t, err)
666+
assert.Contains(t, err.Error(), "Unable to verify the integrity of the 'client_assertion' value.")
667+
})
668+
669+
t.Run("case=should exchange for an access token", func(t *testing.T) {
670+
run := func(strategy string) func(t *testing.T) {
671+
return func(t *testing.T) {
672+
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
673+
674+
token, err := generateAssertion()
675+
require.NoError(t, err)
676+
677+
conf := newConf(client)
678+
conf.EndpointParams.Set("client_assertion", token)
679+
680+
result, err := getToken(t, conf)
681+
require.NoError(t, err)
682+
683+
inspectToken(t, result, client, strategy, false)
684+
}
685+
}
686+
687+
t.Run("strategy=opaque", run("opaque"))
688+
t.Run("strategy=jwt", run("jwt"))
689+
})
690+
691+
t.Run("should call token hook if configured", func(t *testing.T) {
692+
run := func(strategy string) func(t *testing.T) {
693+
return func(t *testing.T) {
694+
token, err := generateAssertion()
695+
require.NoError(t, err)
696+
697+
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
698+
assert.Equal(t, r.Header.Get("Content-Type"), "application/json; charset=UTF-8")
699+
700+
expectedGrantedScopes := []string{client.Scope}
701+
expectedPayload := map[string][]string{
702+
"grant_type": {"client_credentials"},
703+
"client_assertion": {token},
704+
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
705+
"scope": {"offline_access"},
706+
}
707+
708+
var hookReq hydraoauth2.TokenHookRequest
709+
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
710+
require.NotEmpty(t, hookReq.Session)
711+
require.Equal(t, hookReq.Session.Extra, map[string]interface{}{})
712+
require.NotEmpty(t, hookReq.Request)
713+
require.ElementsMatch(t, hookReq.Request.GrantedScopes, expectedGrantedScopes)
714+
require.Equal(t, expectedPayload, hookReq.Request.Payload)
715+
716+
claims := map[string]interface{}{
717+
"hooked": true,
718+
}
719+
720+
hookResp := hydraoauth2.TokenHookResponse{
721+
Session: flow.AcceptOAuth2ConsentRequestSession{
722+
AccessToken: claims,
723+
IDToken: claims,
724+
},
725+
}
726+
727+
w.WriteHeader(http.StatusOK)
728+
require.NoError(t, json.NewEncoder(w).Encode(&hookResp))
729+
}))
730+
defer hs.Close()
731+
732+
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, strategy)
733+
reg.Config().MustSet(ctx, config.KeyTokenHook, hs.URL)
734+
735+
defer reg.Config().MustSet(ctx, config.KeyTokenHook, nil)
736+
737+
conf := newConf(client)
738+
conf.EndpointParams.Set("client_assertion", token)
739+
740+
result, err := getToken(t, conf)
741+
require.NoError(t, err)
742+
743+
inspectToken(t, result, client, strategy, true)
744+
}
745+
}
746+
747+
t.Run("strategy=opaque", run("opaque"))
748+
t.Run("strategy=jwt", run("jwt"))
749+
})
750+
}

oauth2/token_hook.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ func TokenHook(reg interface {
176176
GrantedScopes: requester.GetGrantedScopes(),
177177
GrantedAudience: requester.GetGrantedAudience(),
178178
GrantTypes: requester.GetGrantTypes(),
179-
Payload: requester.Sanitize([]string{"assertion"}).GetRequestForm(),
179+
Payload: requester.Sanitize([]string{"assertion", "client_assertion_type", "client_assertion"}).GetRequestForm(),
180180
}
181181

182182
reqBody := TokenHookRequest{

0 commit comments

Comments
 (0)