Skip to content

Commit 607da43

Browse files
authored
feat: allow amr claim to be array of strings or objects (#2274)
## Summary This PR loosens the validation for the `amr` (Authentication Method Reference) claim in custom access token hooks to accept both array of strings and array of objects, instead of only array of objects. - **Test Coverage**: Added two new test cases: - Modify amr to be array of strings - Verifies that `amr` as an array of strings passes validation - Modify amr to be array of objects - Verifies that `amr` as an array of objects still works (backward compatibility) ## Motivation This change provides more flexibility for custom access token hooks. [RFC-8176 ](https://www.rfc-editor.org/rfc/rfc8176.html#section-1)requires `amr` to be array of strings. ## Testing All tests pass, including: - Existing custom access token tests - New test for `amr` as array of strings - New test for `amr` as array of objects ## Backward Compatibility **Fully backward compatible** - The change only adds support for an additional format. Existing hooks that return `amr` as an array of objects will continue to work without any changes.
1 parent d9de0af commit 607da43

4 files changed

Lines changed: 235 additions & 2 deletions

File tree

internal/api/e2e_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,90 @@ func TestE2EHooks(t *testing.T) {
971971
}
972972
})
973973
}
974+
975+
t.Run("AMRStringArrayUnmarshalling", func(t *testing.T) {
976+
defer inst.HookRecorder.CustomizeAccessToken.ClearCalls()
977+
978+
// Setup hook that returns amr as array of strings
979+
var claimsIn M
980+
hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
981+
w.Header().Add("content-type", "application/json")
982+
w.WriteHeader(http.StatusOK)
983+
984+
err := json.NewDecoder(r.Body).Decode(&claimsIn)
985+
require.NoError(t, err)
986+
987+
// Modify amr to be array of strings instead of objects
988+
claimsOut := copyMap(t, claimsIn)
989+
claimsOut["claims"].(M)["amr"] = []string{"password", "totp"}
990+
991+
err = json.NewEncoder(w).Encode(claimsOut)
992+
require.NoError(t, err)
993+
})
994+
995+
inst.HookRecorder.CustomizeAccessToken.ClearCalls()
996+
inst.HookRecorder.CustomizeAccessToken.SetHandler(hr)
997+
998+
// Get token with modified amr
999+
req := &api.PasswordGrantParams{
1000+
Email: string(currentUser.Email),
1001+
Password: defaultPassword,
1002+
}
1003+
1004+
res := new(api.AccessTokenResponse)
1005+
err := e2eapi.Do(ctx, http.MethodPost, inst.APIServer.URL+"/token?grant_type=password", req, res)
1006+
require.NoError(t, err)
1007+
require.True(t, len(res.Token) > 0)
1008+
1009+
// Verify hook was called
1010+
{
1011+
calls := inst.HookRecorder.CustomizeAccessToken.GetCalls()
1012+
require.Equal(t, 1, len(calls))
1013+
}
1014+
1015+
// Parse token to verify it can be unmarshalled
1016+
p := jwt.NewParser(jwt.WithValidMethods(globalCfg.JWT.ValidMethods))
1017+
token, err := p.ParseWithClaims(
1018+
res.Token,
1019+
&api.AccessTokenClaims{},
1020+
func(token *jwt.Token) (any, error) {
1021+
if kid, ok := token.Header["kid"]; ok {
1022+
if kidStr, ok := kid.(string); ok {
1023+
return conf.FindPublicKeyByKid(kidStr, &globalCfg.JWT)
1024+
}
1025+
}
1026+
if alg, ok := token.Header["alg"]; ok {
1027+
if alg == jwt.SigningMethodHS256.Name {
1028+
return []byte(globalCfg.JWT.Secret), nil
1029+
}
1030+
}
1031+
return nil, fmt.Errorf("missing kid")
1032+
})
1033+
require.NoError(t, err, "Token should parse successfully even with string array amr")
1034+
1035+
fmt.Println("token hereee", res.Token)
1036+
// Verify claims were unmarshalled correctly
1037+
claims, ok := token.Claims.(*api.AccessTokenClaims)
1038+
require.True(t, ok, "Claims should be AccessTokenClaims type")
1039+
require.NotNil(t, claims.AuthenticationMethodReference, "AMR should not be nil")
1040+
require.Len(t, claims.AuthenticationMethodReference, 2, "AMR should have 2 entries")
1041+
require.Equal(t, "password", claims.AuthenticationMethodReference[0].Method)
1042+
require.Equal(t, "totp", claims.AuthenticationMethodReference[1].Method)
1043+
1044+
// Call /user endpoint with the token to verify it works end-to-end
1045+
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, "/user", nil)
1046+
require.NoError(t, err)
1047+
1048+
httpRes, err := inst.DoAuth(httpReq, res.Token)
1049+
require.NoError(t, err, "Should be able to call /user endpoint with token containing string array amr")
1050+
require.Equal(t, http.StatusOK, httpRes.StatusCode, "/user endpoint should return 200 OK")
1051+
1052+
// Verify we got user data back
1053+
var userData models.User
1054+
err = json.NewDecoder(httpRes.Body).Decode(&userData)
1055+
require.NoError(t, err, "Should be able to decode user response")
1056+
require.Equal(t, currentUser.ID, userData.ID, "Should get the correct user")
1057+
})
9741058
})
9751059

9761060
t.Run("SendEmail", func(t *testing.T) {

internal/api/token_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,43 @@ end; $$ language plpgsql;`,
741741
"user_metadata": nil,
742742
},
743743
shouldError: false,
744+
}, {
745+
desc: "Modify amr to be array of strings",
746+
uri: "pg-functions://postgres/auth/custom_access_token_amr_strings",
747+
hookFunctionSQL: `
748+
create or replace function custom_access_token_amr_strings(input jsonb)
749+
returns jsonb as $$
750+
declare
751+
result jsonb;
752+
begin
753+
input := jsonb_set(input, '{claims,amr}', '["password", "mfa"]'::jsonb);
754+
result := jsonb_build_object('claims', input->'claims');
755+
return result;
756+
end; $$ language plpgsql;`,
757+
expectedClaims: map[string]interface{}{
758+
"amr": []interface{}{"password", "mfa"},
759+
},
760+
shouldError: false,
761+
}, {
762+
desc: "Modify amr to be array of objects",
763+
uri: "pg-functions://postgres/auth/custom_access_token_amr_objects",
764+
hookFunctionSQL: `
765+
create or replace function custom_access_token_amr_objects(input jsonb)
766+
returns jsonb as $$
767+
declare
768+
result jsonb;
769+
begin
770+
input := jsonb_set(input, '{claims,amr}', '[{"method": "password"}, {"method": "mfa"}]'::jsonb);
771+
result := jsonb_build_object('claims', input->'claims');
772+
return result;
773+
end; $$ language plpgsql;`,
774+
expectedClaims: map[string]interface{}{
775+
"amr": []interface{}{
776+
map[string]interface{}{"method": "password"},
777+
map[string]interface{}{"method": "mfa"},
778+
},
779+
},
780+
shouldError: false,
744781
},
745782
}
746783
for _, c := range cases {

internal/tokens/service.go

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tokens
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
mathRand "math/rand"
78
"net/http"
@@ -26,6 +27,47 @@ import (
2627

2728
const retryLoopDuration = 5.0
2829

30+
// AMRClaim supports unmarshalling AMR as either strings or AMREntry objects.
31+
type AMRClaim []models.AMREntry
32+
33+
// UnmarshalJSON accepts either an array of strings or AMREntry objects.
34+
func (a *AMRClaim) UnmarshalJSON(data []byte) error {
35+
// Handle null explicitly - null cannot be unmarshaled into a slice
36+
if len(data) > 0 {
37+
trimmed := strings.TrimSpace(string(data))
38+
if trimmed == "null" {
39+
*a = AMRClaim{}
40+
return nil
41+
}
42+
}
43+
44+
var rawItems []json.RawMessage
45+
if err := json.Unmarshal(data, &rawItems); err != nil {
46+
return err
47+
}
48+
49+
entries := make([]models.AMREntry, 0, len(rawItems))
50+
for _, item := range rawItems {
51+
var method string
52+
if err := json.Unmarshal(item, &method); err == nil {
53+
entries = append(entries, models.AMREntry{
54+
Method: method,
55+
Timestamp: time.Now().Unix(),
56+
})
57+
continue
58+
}
59+
60+
var entry models.AMREntry
61+
if err := json.Unmarshal(item, &entry); err != nil {
62+
return err
63+
}
64+
entries = append(entries, entry)
65+
}
66+
67+
*a = entries
68+
return nil
69+
}
70+
2971
// AccessTokenClaims is a struct thats used for JWT claims
3072
type AccessTokenClaims struct {
3173
jwt.RegisteredClaims
@@ -35,7 +77,7 @@ type AccessTokenClaims struct {
3577
UserMetaData map[string]interface{} `json:"user_metadata"`
3678
Role string `json:"role"`
3779
AuthenticatorAssuranceLevel string `json:"aal,omitempty"`
38-
AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"`
80+
AuthenticationMethodReference AMRClaim `json:"amr,omitempty"`
3981
SessionId string `json:"session_id,omitempty"`
4082
IsAnonymous bool `json:"is_anonymous"`
4183
ClientID string `json:"client_id,omitempty"`
@@ -951,7 +993,10 @@ const MinimumViableTokenSchema = `{
951993
"amr": {
952994
"type": "array",
953995
"items": {
954-
"type": "object"
996+
"anyOf": [
997+
{"type": "string"},
998+
{"type": "object"}
999+
]
9551000
}
9561001
},
9571002
"session_id": {

internal/tokens/service_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"encoding/base64"
7+
"encoding/json"
78
"net/http"
89
"strconv"
910
"strings"
@@ -1022,3 +1023,69 @@ func (ts *IDTokenTestSuite) TestIDTokenWithMultipleScopes() {
10221023
phoneNumber, hasPhone := claims["phone_number"]
10231024
require.False(ts.T(), hasPhone || (phoneNumber != nil && phoneNumber != ""), "phone_number claim should not be present without phone scope")
10241025
}
1026+
1027+
func TestAMRClaimUnmarshal(t *testing.T) {
1028+
t.Run("mixed string and object formats", func(t *testing.T) {
1029+
var claim AMRClaim
1030+
before := time.Now().Unix()
1031+
1032+
err := json.Unmarshal([]byte(`["password", {"method":"totp","timestamp":123,"provider":"webauthn"}]`), &claim)
1033+
require.NoError(t, err)
1034+
require.Len(t, claim, 2)
1035+
1036+
require.Equal(t, "password", claim[0].Method)
1037+
require.GreaterOrEqual(t, claim[0].Timestamp, before)
1038+
require.LessOrEqual(t, claim[0].Timestamp, time.Now().Unix())
1039+
require.Empty(t, claim[0].Provider, "string format should not have provider")
1040+
1041+
require.Equal(t, "totp", claim[1].Method)
1042+
require.Equal(t, int64(123), claim[1].Timestamp)
1043+
require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved from object format")
1044+
})
1045+
1046+
t.Run("object with provider", func(t *testing.T) {
1047+
var claim AMRClaim
1048+
err := json.Unmarshal([]byte(`[{"method":"sso","timestamp":456,"provider":"saml"}]`), &claim)
1049+
require.NoError(t, err)
1050+
require.Len(t, claim, 1)
1051+
require.Equal(t, "sso", claim[0].Method)
1052+
require.Equal(t, int64(456), claim[0].Timestamp)
1053+
require.Equal(t, "saml", claim[0].Provider, "provider should be preserved")
1054+
})
1055+
1056+
t.Run("object without provider", func(t *testing.T) {
1057+
var claim AMRClaim
1058+
err := json.Unmarshal([]byte(`[{"method":"password","timestamp":789}]`), &claim)
1059+
require.NoError(t, err)
1060+
require.Len(t, claim, 1)
1061+
require.Equal(t, "password", claim[0].Method)
1062+
require.Equal(t, int64(789), claim[0].Timestamp)
1063+
require.Empty(t, claim[0].Provider, "provider should be empty when not provided")
1064+
})
1065+
1066+
t.Run("all strings", func(t *testing.T) {
1067+
var claim AMRClaim
1068+
before := time.Now().Unix()
1069+
err := json.Unmarshal([]byte(`["password", "totp"]`), &claim)
1070+
require.NoError(t, err)
1071+
require.Len(t, claim, 2)
1072+
require.Equal(t, "password", claim[0].Method)
1073+
require.Equal(t, "totp", claim[1].Method)
1074+
require.GreaterOrEqual(t, claim[0].Timestamp, before)
1075+
require.Empty(t, claim[0].Provider)
1076+
require.Empty(t, claim[1].Provider)
1077+
})
1078+
1079+
t.Run("all objects", func(t *testing.T) {
1080+
var claim AMRClaim
1081+
err := json.Unmarshal([]byte(`[{"method":"password","timestamp":100},{"method":"totp","timestamp":200,"provider":"webauthn"}]`), &claim)
1082+
require.NoError(t, err)
1083+
require.Len(t, claim, 2)
1084+
require.Equal(t, "password", claim[0].Method)
1085+
require.Equal(t, int64(100), claim[0].Timestamp)
1086+
require.Empty(t, claim[0].Provider)
1087+
require.Equal(t, "totp", claim[1].Method)
1088+
require.Equal(t, int64(200), claim[1].Timestamp)
1089+
require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved")
1090+
})
1091+
}

0 commit comments

Comments
 (0)