Skip to content

Commit 788ce10

Browse files
committed
feat(custom-oauth): add per-provider custom_claims_allowlist
1 parent 169ad67 commit 788ce10

9 files changed

Lines changed: 557 additions & 61 deletions

internal/api/custom_oauth_admin.go

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ type AdminCustomOAuthProviderParams struct {
5656
Scopes []string `json:"scopes"`
5757
PKCEEnabled *bool `json:"pkce_enabled,omitempty"`
5858
AttributeMapping map[string]interface{} `json:"attribute_mapping,omitempty"`
59-
AuthorizationParams map[string]interface{} `json:"authorization_params,omitempty"`
60-
Enabled *bool `json:"enabled,omitempty"`
61-
EmailOptional *bool `json:"email_optional,omitempty"`
59+
// CustomClaimsAllowlist lists raw IdP claim keys to copy verbatim into custom_claims.
60+
CustomClaimsAllowlist []string `json:"custom_claims_allowlist,omitempty"`
61+
AuthorizationParams map[string]interface{} `json:"authorization_params,omitempty"`
62+
Enabled *bool `json:"enabled,omitempty"`
63+
EmailOptional *bool `json:"email_optional,omitempty"`
6264

6365
// OIDC-specific fields
6466
Issuer string `json:"issuer,omitempty"`
@@ -173,6 +175,11 @@ func (a *API) adminCustomOAuthProviderCreate(w http.ResponseWriter, r *http.Requ
173175
return err
174176
}
175177

178+
// Validate custom claims allowlist (non-empty source keys)
179+
if err := validateCustomClaimsAllowlist(params.CustomClaimsAllowlist); err != nil {
180+
return err
181+
}
182+
176183
// Check quota if configured
177184
if config.CustomOAuth.MaxProviders > 0 {
178185
totalCount, err := models.CountCustomOAuthProviders(db)
@@ -281,6 +288,13 @@ func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Requ
281288
}
282289
}
283290

291+
// Validate custom claims allowlist if provided
292+
if params.CustomClaimsAllowlist != nil {
293+
if err := validateCustomClaimsAllowlist(params.CustomClaimsAllowlist); err != nil {
294+
return err
295+
}
296+
}
297+
284298
// Read the existing provider outside the write transaction so the
285299
// network call (discovery fetch) doesn't hold a transaction open.
286300
provider, err := models.FindCustomOAuthProviderByIdentifier(db, identifier)
@@ -480,18 +494,19 @@ func buildProviderFromParams(params *AdminCustomOAuthProviderParams, providerTyp
480494
// Generate ID upfront so it's available for client secret encryption (used as AAD)
481495
id, _ := uuid.NewV4()
482496
provider := &models.CustomOAuthProvider{
483-
ID: id,
484-
ProviderType: providerType,
485-
Identifier: params.Identifier,
486-
Name: params.Name,
487-
ClientID: params.ClientID,
488-
AcceptableClientIDs: popslices.String(params.AcceptableClientIDs),
489-
Scopes: popslices.String(params.Scopes),
490-
PKCEEnabled: getBoolOrDefault(params.PKCEEnabled, true),
491-
AttributeMapping: popslices.Map(params.AttributeMapping),
492-
AuthorizationParams: popslices.Map(params.AuthorizationParams),
493-
Enabled: getBoolOrDefault(params.Enabled, true),
494-
EmailOptional: getBoolOrDefault(params.EmailOptional, false),
497+
ID: id,
498+
ProviderType: providerType,
499+
Identifier: params.Identifier,
500+
Name: params.Name,
501+
ClientID: params.ClientID,
502+
AcceptableClientIDs: popslices.String(params.AcceptableClientIDs),
503+
Scopes: popslices.String(params.Scopes),
504+
PKCEEnabled: getBoolOrDefault(params.PKCEEnabled, true),
505+
AttributeMapping: popslices.Map(params.AttributeMapping),
506+
CustomClaimsAllowlist: popslices.String(params.CustomClaimsAllowlist),
507+
AuthorizationParams: popslices.Map(params.AuthorizationParams),
508+
Enabled: getBoolOrDefault(params.Enabled, true),
509+
EmailOptional: getBoolOrDefault(params.EmailOptional, false),
495510
}
496511

497512
// Set type-specific fields
@@ -563,6 +578,9 @@ func updateProviderFromParams(provider *models.CustomOAuthProvider, params *Admi
563578
if params.AttributeMapping != nil {
564579
provider.AttributeMapping = popslices.Map(params.AttributeMapping)
565580
}
581+
if params.CustomClaimsAllowlist != nil {
582+
provider.CustomClaimsAllowlist = popslices.String(params.CustomClaimsAllowlist)
583+
}
566584
if params.AuthorizationParams != nil {
567585
provider.AuthorizationParams = popslices.Map(params.AuthorizationParams)
568586
}
@@ -778,3 +796,18 @@ func validateAttributeMapping(mapping map[string]interface{}) error {
778796
return nil
779797
}
780798

799+
// validateCustomClaimsAllowlist ensures every allowlist entry is a non-empty
800+
// source claim key. Unlike attribute_mapping, these are opaque source keys
801+
// copied into custom_claims (not typed targets)
802+
func validateCustomClaimsAllowlist(allowlist []string) error {
803+
for _, key := range allowlist {
804+
if strings.TrimSpace(key) == "" {
805+
return apierrors.NewBadRequestError(
806+
apierrors.ErrorCodeValidationFailed,
807+
"custom_claims_allowlist entries must be non-empty strings",
808+
)
809+
}
810+
}
811+
812+
return nil
813+
}

internal/api/custom_oauth_admin_test.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010
"testing"
1111

1212
popslices "github.com/gobuffalo/pop/v6/slices"
13-
jwt "github.com/golang-jwt/jwt/v5"
1413
"github.com/gofrs/uuid"
14+
jwt "github.com/golang-jwt/jwt/v5"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
1717
"github.com/stretchr/testify/suite"
@@ -596,6 +596,56 @@ func (ts *CustomOAuthAdminTestSuite) TestUpdateProvider() {
596596
assert.Equal(ts.T(), popslices.String{"openid", "profile", "email"}, updated.Scopes)
597597
}
598598

599+
func (ts *CustomOAuthAdminTestSuite) TestCustomClaimsAllowlistCreateUpdateGet() {
600+
payload := ts.createTestOAuth2Payload("allowlist-provider")
601+
payload["custom_claims_allowlist"] = []string{"groups", "org_id"}
602+
603+
w := ts.createProvider(payload, http.StatusCreated)
604+
605+
var created models.CustomOAuthProvider
606+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&created))
607+
assert.Equal(ts.T(), popslices.String{"groups", "org_id"}, created.CustomClaimsAllowlist)
608+
609+
// PUT replaces the allowlist.
610+
updatePayload := map[string]interface{}{
611+
"custom_claims_allowlist": []string{"mail", "sn", "nlEduPersonProfileId"},
612+
}
613+
var body bytes.Buffer
614+
require.NoError(ts.T(), json.NewEncoder(&body).Encode(updatePayload))
615+
616+
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), &body)
617+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
618+
w = httptest.NewRecorder()
619+
ts.API.handler.ServeHTTP(w, req)
620+
require.Equal(ts.T(), http.StatusOK, w.Code)
621+
622+
var updated models.CustomOAuthProvider
623+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&updated))
624+
assert.Equal(ts.T(), popslices.String{"mail", "sn", "nlEduPersonProfileId"}, updated.CustomClaimsAllowlist)
625+
626+
// GET returns the persisted allowlist.
627+
req = httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), nil)
628+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
629+
w = httptest.NewRecorder()
630+
ts.API.handler.ServeHTTP(w, req)
631+
require.Equal(ts.T(), http.StatusOK, w.Code)
632+
633+
var fetched models.CustomOAuthProvider
634+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&fetched))
635+
assert.Equal(ts.T(), popslices.String{"mail", "sn", "nlEduPersonProfileId"}, fetched.CustomClaimsAllowlist)
636+
}
637+
638+
func (ts *CustomOAuthAdminTestSuite) TestCustomClaimsAllowlistRejectsEmptyEntries() {
639+
payload := ts.createTestOAuth2Payload("allowlist-bad")
640+
payload["custom_claims_allowlist"] = []string{"groups", ""}
641+
642+
w := ts.createProvider(payload, http.StatusBadRequest)
643+
644+
var apiErr apierrors.HTTPError
645+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&apiErr))
646+
assert.Equal(ts.T(), apierrors.ErrorCodeValidationFailed, apiErr.ErrorCode)
647+
}
648+
599649
// Test DELETE /admin/custom-providers/:id (Delete)
600650

601651
func (ts *CustomOAuthAdminTestSuite) TestDeleteProvider() {

internal/api/external.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ func (a *API) loadCustomProvider(ctx context.Context, db *storage.Connection, id
744744
customProvider.AcceptableClientIDs,
745745
customProvider.AttributeMapping,
746746
customProvider.AuthorizationParams,
747+
customProvider.CustomClaimsAllowlist,
747748
)
748749

749750
// Build provider configuration
@@ -777,6 +778,7 @@ func (a *API) loadCustomProvider(ctx context.Context, db *storage.Connection, id
777778
customProvider.AcceptableClientIDs,
778779
customProvider.AttributeMapping,
779780
customProvider.AuthorizationParams,
781+
customProvider.CustomClaimsAllowlist,
780782
a.oidcCache,
781783
)
782784
if err != nil {

internal/api/provider/custom_oauth.go

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ import (
1212

1313
// CustomOAuthProvider implements OAuthProvider for custom OAuth2 providers
1414
type CustomOAuthProvider struct {
15-
config *oauth2.Config
16-
userinfoURL string
17-
pkceEnabled bool
18-
acceptableClientIDs []string
19-
attributeMapping map[string]interface{}
20-
authorizationParams map[string]interface{}
15+
config *oauth2.Config
16+
userinfoURL string
17+
pkceEnabled bool
18+
acceptableClientIDs []string
19+
attributeMapping map[string]interface{}
20+
authorizationParams map[string]interface{}
21+
customClaimsAllowlist []string
2122
}
2223

2324
// NewCustomOAuthProvider creates a new custom OAuth provider
@@ -27,6 +28,7 @@ func NewCustomOAuthProvider(
2728
pkceEnabled bool,
2829
acceptableClientIDs []string,
2930
attributeMapping, authorizationParams map[string]interface{},
31+
customClaimsAllowlist []string,
3032
) *CustomOAuthProvider {
3133
config := &oauth2.Config{
3234
ClientID: clientID,
@@ -40,12 +42,13 @@ func NewCustomOAuthProvider(
4042
}
4143

4244
return &CustomOAuthProvider{
43-
config: config,
44-
userinfoURL: userinfoURL,
45-
pkceEnabled: pkceEnabled,
46-
acceptableClientIDs: acceptableClientIDs,
47-
attributeMapping: attributeMapping,
48-
authorizationParams: authorizationParams,
45+
config: config,
46+
userinfoURL: userinfoURL,
47+
pkceEnabled: pkceEnabled,
48+
acceptableClientIDs: acceptableClientIDs,
49+
attributeMapping: attributeMapping,
50+
authorizationParams: authorizationParams,
51+
customClaimsAllowlist: customClaimsAllowlist,
4952
}
5053
}
5154

@@ -68,11 +71,14 @@ func (p *CustomOAuthProvider) GetOAuthToken(ctx context.Context, code string, op
6871

6972
// GetUserData fetches user data from the provider's userinfo endpoint
7073
func (p *CustomOAuthProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
71-
var claims Claims
72-
if err := makeRequest(ctx, tok, p.config, p.userinfoURL, &claims); err != nil {
74+
claims, raw, err := fetchUserinfoClaims(ctx, tok, p.config, p.userinfoURL)
75+
if err != nil {
7376
return nil, err
7477
}
7578

79+
// Capture allowlisted custom claims before attribute mapping
80+
captureAllowedClaims(raw, p.customClaimsAllowlist, &claims)
81+
7682
// Apply attribute mapping if configured
7783
if len(p.attributeMapping) > 0 {
7884
claims = applyAttributeMapping(claims, p.attributeMapping)
@@ -101,13 +107,14 @@ func (p *CustomOAuthProvider) RequiresPKCE() bool {
101107

102108
// CustomOIDCProvider implements OAuthProvider for custom OIDC providers
103109
type CustomOIDCProvider struct {
104-
config *oauth2.Config
105-
oidcProvider *oidc.Provider
106-
userinfoEndpoint string
107-
pkceEnabled bool
108-
acceptableClientIDs []string
109-
attributeMapping map[string]interface{}
110-
authorizationParams map[string]interface{}
110+
config *oauth2.Config
111+
oidcProvider *oidc.Provider
112+
userinfoEndpoint string
113+
pkceEnabled bool
114+
acceptableClientIDs []string
115+
attributeMapping map[string]interface{}
116+
authorizationParams map[string]interface{}
117+
customClaimsAllowlist []string
111118
}
112119

113120
// NewCustomOIDCProvider creates a new custom OIDC provider
@@ -119,6 +126,7 @@ func NewCustomOIDCProvider(
119126
pkceEnabled bool,
120127
acceptableClientIDs []string,
121128
attributeMapping, authorizationParams map[string]interface{},
129+
customClaimsAllowlist []string,
122130
cache *OIDCProviderCache,
123131
) (*CustomOIDCProvider, error) {
124132
// Ensure 'openid' scope is always present for OIDC
@@ -152,13 +160,14 @@ func NewCustomOIDCProvider(
152160
}
153161

154162
return &CustomOIDCProvider{
155-
config: config,
156-
oidcProvider: oidcProvider,
157-
userinfoEndpoint: userinfoEndpoint,
158-
pkceEnabled: pkceEnabled,
159-
acceptableClientIDs: acceptableClientIDs,
160-
attributeMapping: attributeMapping,
161-
authorizationParams: authorizationParams,
163+
config: config,
164+
oidcProvider: oidcProvider,
165+
userinfoEndpoint: userinfoEndpoint,
166+
pkceEnabled: pkceEnabled,
167+
acceptableClientIDs: acceptableClientIDs,
168+
attributeMapping: attributeMapping,
169+
authorizationParams: authorizationParams,
170+
customClaimsAllowlist: customClaimsAllowlist,
162171
}, nil
163172
}
164173

@@ -198,6 +207,17 @@ func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token)
198207
return nil, err
199208
}
200209

210+
// Capture allowlisted custom claims from the raw ID token before
211+
// attribute mapping. Because we only copy explicitly listed keys, there
212+
// is no risk of re-adding keys a parser intentionally stripped (e.g. Azure).
213+
if len(p.customClaimsAllowlist) > 0 && userData.Metadata != nil {
214+
var raw map[string]interface{}
215+
if err := idTokenObj.Claims(&raw); err != nil {
216+
return nil, fmt.Errorf("failed to read ID token claims: %w", err)
217+
}
218+
captureAllowedClaims(raw, p.customClaimsAllowlist, userData.Metadata)
219+
}
220+
201221
// Apply attribute mapping to the metadata from ID token
202222
if len(p.attributeMapping) > 0 && userData.Metadata != nil {
203223
*userData.Metadata = applyAttributeMapping(*userData.Metadata, p.attributeMapping)
@@ -208,11 +228,14 @@ func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token)
208228

209229
// No ID token, use userinfo endpoint
210230
if p.userinfoEndpoint != "" {
211-
var claims Claims
212-
if err := makeRequest(ctx, tok, p.config, p.userinfoEndpoint, &claims); err != nil {
231+
claims, raw, err := fetchUserinfoClaims(ctx, tok, p.config, p.userinfoEndpoint)
232+
if err != nil {
213233
return nil, err
214234
}
215235

236+
// Capture allowlisted custom claims before attribute mapping
237+
captureAllowedClaims(raw, p.customClaimsAllowlist, &claims)
238+
216239
// Apply attribute mapping
217240
if len(p.attributeMapping) > 0 {
218241
claims = applyAttributeMapping(claims, p.attributeMapping)
@@ -265,6 +288,44 @@ func (p *CustomOIDCProvider) validateAudience(audiences []string) error {
265288
return fmt.Errorf("token audience %v does not match any acceptable client ID", audiences)
266289
}
267290

291+
// fetchUserinfoClaims fetches the userinfo response once and returns both the
292+
// typed Claims and the raw claim map. The raw map is needed so that arbitrary
293+
// allowlisted keys (which have no typed field) can be copied verbatim.
294+
func fetchUserinfoClaims(ctx context.Context, tok *oauth2.Token, config *oauth2.Config, url string) (Claims, map[string]interface{}, error) {
295+
var raw map[string]interface{}
296+
if err := makeRequest(ctx, tok, config, url, &raw); err != nil {
297+
return Claims{}, nil, err
298+
}
299+
300+
var claims Claims
301+
b, err := json.Marshal(raw)
302+
if err != nil {
303+
return Claims{}, nil, err
304+
}
305+
if err := json.Unmarshal(b, &claims); err != nil {
306+
return Claims{}, nil, err
307+
}
308+
309+
return claims, raw, nil
310+
}
311+
312+
// captureAllowedClaims copies each allowlisted key present in raw into
313+
// c.CustomClaims verbatim. An empty allowlist captures nothing (D1), and keys
314+
// absent from raw are silently skipped (no nil entry is created). Because only
315+
// explicitly listed keys are copied, protocol/registered claims never leak.
316+
func captureAllowedClaims(raw map[string]interface{}, allowlist []string, c *Claims) {
317+
for _, key := range allowlist {
318+
value, ok := raw[key]
319+
if !ok {
320+
continue
321+
}
322+
if c.CustomClaims == nil {
323+
c.CustomClaims = make(map[string]interface{})
324+
}
325+
c.CustomClaims[key] = value
326+
}
327+
}
328+
268329
// applyAttributeMapping applies custom attribute mapping to claims
269330
func applyAttributeMapping(claims Claims, mapping map[string]interface{}) Claims {
270331
// Create a map representation of claims for easier manipulation

0 commit comments

Comments
 (0)