Skip to content

Commit c18b10a

Browse files
committed
feat(custom-oauth): add per-provider custom_claims_allowlist
1 parent 3dacc64 commit c18b10a

9 files changed

Lines changed: 558 additions & 62 deletions

internal/api/custom_oauth_admin.go

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ type AdminCustomOAuthProviderParams struct {
5353
Scopes []string `json:"scopes"`
5454
PKCEEnabled *bool `json:"pkce_enabled,omitempty"`
5555
AttributeMapping map[string]interface{} `json:"attribute_mapping,omitempty"`
56-
AuthorizationParams map[string]interface{} `json:"authorization_params,omitempty"`
57-
Enabled *bool `json:"enabled,omitempty"`
58-
EmailOptional *bool `json:"email_optional,omitempty"`
56+
// CustomClaimsAllowlist lists raw IdP claim keys to copy verbatim into custom_claims.
57+
CustomClaimsAllowlist []string `json:"custom_claims_allowlist,omitempty"`
58+
AuthorizationParams map[string]interface{} `json:"authorization_params,omitempty"`
59+
Enabled *bool `json:"enabled,omitempty"`
60+
EmailOptional *bool `json:"email_optional,omitempty"`
5961

6062
// OIDC-specific fields
6163
Issuer string `json:"issuer,omitempty"`
@@ -170,6 +172,11 @@ func (a *API) adminCustomOAuthProviderCreate(w http.ResponseWriter, r *http.Requ
170172
return err
171173
}
172174

175+
// Validate custom claims allowlist (non-empty source keys)
176+
if err := validateCustomClaimsAllowlist(params.CustomClaimsAllowlist); err != nil {
177+
return err
178+
}
179+
173180
// Check quota if configured
174181
if config.CustomOAuth.MaxProviders > 0 {
175182
totalCount, err := models.CountCustomOAuthProviders(db)
@@ -278,6 +285,13 @@ func (a *API) adminCustomOAuthProviderUpdate(w http.ResponseWriter, r *http.Requ
278285
}
279286
}
280287

288+
// Validate custom claims allowlist if provided
289+
if params.CustomClaimsAllowlist != nil {
290+
if err := validateCustomClaimsAllowlist(params.CustomClaimsAllowlist); err != nil {
291+
return err
292+
}
293+
}
294+
281295
// Read the existing provider outside the write transaction so the
282296
// network call (discovery fetch) doesn't hold a transaction open.
283297
provider, err := models.FindCustomOAuthProviderByIdentifier(db, identifier)
@@ -477,18 +491,19 @@ func buildProviderFromParams(params *AdminCustomOAuthProviderParams, providerTyp
477491
// Generate ID upfront so it's available for client secret encryption (used as AAD)
478492
id, _ := uuid.NewV4()
479493
provider := &models.CustomOAuthProvider{
480-
ID: id,
481-
ProviderType: providerType,
482-
Identifier: params.Identifier,
483-
Name: params.Name,
484-
ClientID: params.ClientID,
485-
AcceptableClientIDs: popslices.String(params.AcceptableClientIDs),
486-
Scopes: popslices.String(params.Scopes),
487-
PKCEEnabled: getBoolOrDefault(params.PKCEEnabled, true),
488-
AttributeMapping: popslices.Map(params.AttributeMapping),
489-
AuthorizationParams: popslices.Map(params.AuthorizationParams),
490-
Enabled: getBoolOrDefault(params.Enabled, true),
491-
EmailOptional: getBoolOrDefault(params.EmailOptional, false),
494+
ID: id,
495+
ProviderType: providerType,
496+
Identifier: params.Identifier,
497+
Name: params.Name,
498+
ClientID: params.ClientID,
499+
AcceptableClientIDs: popslices.String(params.AcceptableClientIDs),
500+
Scopes: popslices.String(params.Scopes),
501+
PKCEEnabled: getBoolOrDefault(params.PKCEEnabled, true),
502+
AttributeMapping: popslices.Map(params.AttributeMapping),
503+
CustomClaimsAllowlist: popslices.String(params.CustomClaimsAllowlist),
504+
AuthorizationParams: popslices.Map(params.AuthorizationParams),
505+
Enabled: getBoolOrDefault(params.Enabled, true),
506+
EmailOptional: getBoolOrDefault(params.EmailOptional, false),
492507
}
493508

494509
// Set type-specific fields
@@ -560,6 +575,9 @@ func updateProviderFromParams(provider *models.CustomOAuthProvider, params *Admi
560575
if params.AttributeMapping != nil {
561576
provider.AttributeMapping = popslices.Map(params.AttributeMapping)
562577
}
578+
if params.CustomClaimsAllowlist != nil {
579+
provider.CustomClaimsAllowlist = popslices.String(params.CustomClaimsAllowlist)
580+
}
563581
if params.AuthorizationParams != nil {
564582
provider.AuthorizationParams = popslices.Map(params.AuthorizationParams)
565583
}
@@ -749,3 +767,18 @@ func validateAttributeMapping(mapping map[string]interface{}) error {
749767
return nil
750768
}
751769

770+
// validateCustomClaimsAllowlist ensures every allowlist entry is a non-empty
771+
// source claim key. Unlike attribute_mapping, these are opaque source keys
772+
// copied into custom_claims (not typed targets)
773+
func validateCustomClaimsAllowlist(allowlist []string) error {
774+
for _, key := range allowlist {
775+
if strings.TrimSpace(key) == "" {
776+
return apierrors.NewBadRequestError(
777+
apierrors.ErrorCodeValidationFailed,
778+
"custom_claims_allowlist entries must be non-empty strings",
779+
)
780+
}
781+
}
782+
783+
return nil
784+
}

internal/api/custom_oauth_admin_test.go

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010
"testing"
1111

1212
popslices "github.com/gobuffalo/pop/v6/slices"
13-
jwt "github.com/golang-jwt/jwt/v5"
1413
"github.com/gofrs/uuid"
14+
jwt "github.com/golang-jwt/jwt/v5"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
1717
"github.com/stretchr/testify/suite"
@@ -596,6 +596,56 @@ func (ts *CustomOAuthAdminTestSuite) TestUpdateProvider() {
596596
assert.Equal(ts.T(), popslices.String{"openid", "profile", "email"}, updated.Scopes)
597597
}
598598

599+
func (ts *CustomOAuthAdminTestSuite) TestCustomClaimsAllowlistCreateUpdateGet() {
600+
payload := ts.createTestOAuth2Payload("allowlist-provider")
601+
payload["custom_claims_allowlist"] = []string{"groups", "org_id"}
602+
603+
w := ts.createProvider(payload, http.StatusCreated)
604+
605+
var created models.CustomOAuthProvider
606+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&created))
607+
assert.Equal(ts.T(), popslices.String{"groups", "org_id"}, created.CustomClaimsAllowlist)
608+
609+
// PUT replaces the allowlist.
610+
updatePayload := map[string]interface{}{
611+
"custom_claims_allowlist": []string{"mail", "sn", "nlEduPersonProfileId"},
612+
}
613+
var body bytes.Buffer
614+
require.NoError(ts.T(), json.NewEncoder(&body).Encode(updatePayload))
615+
616+
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), &body)
617+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
618+
w = httptest.NewRecorder()
619+
ts.API.handler.ServeHTTP(w, req)
620+
require.Equal(ts.T(), http.StatusOK, w.Code)
621+
622+
var updated models.CustomOAuthProvider
623+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&updated))
624+
assert.Equal(ts.T(), popslices.String{"mail", "sn", "nlEduPersonProfileId"}, updated.CustomClaimsAllowlist)
625+
626+
// GET returns the persisted allowlist.
627+
req = httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/custom-providers/%s", created.Identifier), nil)
628+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))
629+
w = httptest.NewRecorder()
630+
ts.API.handler.ServeHTTP(w, req)
631+
require.Equal(ts.T(), http.StatusOK, w.Code)
632+
633+
var fetched models.CustomOAuthProvider
634+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&fetched))
635+
assert.Equal(ts.T(), popslices.String{"mail", "sn", "nlEduPersonProfileId"}, fetched.CustomClaimsAllowlist)
636+
}
637+
638+
func (ts *CustomOAuthAdminTestSuite) TestCustomClaimsAllowlistRejectsEmptyEntries() {
639+
payload := ts.createTestOAuth2Payload("allowlist-bad")
640+
payload["custom_claims_allowlist"] = []string{"groups", ""}
641+
642+
w := ts.createProvider(payload, http.StatusBadRequest)
643+
644+
var apiErr apierrors.HTTPError
645+
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&apiErr))
646+
assert.Equal(ts.T(), apierrors.ErrorCodeValidationFailed, apiErr.ErrorCode)
647+
}
648+
599649
// Test DELETE /admin/custom-providers/:id (Delete)
600650

601651
func (ts *CustomOAuthAdminTestSuite) TestDeleteProvider() {

internal/api/external.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,7 @@ func (a *API) loadCustomProvider(ctx context.Context, db *storage.Connection, id
743743
customProvider.AcceptableClientIDs,
744744
customProvider.AttributeMapping,
745745
customProvider.AuthorizationParams,
746+
customProvider.CustomClaimsAllowlist,
746747
)
747748

748749
// Build provider configuration
@@ -776,6 +777,7 @@ func (a *API) loadCustomProvider(ctx context.Context, db *storage.Connection, id
776777
customProvider.AcceptableClientIDs,
777778
customProvider.AttributeMapping,
778779
customProvider.AuthorizationParams,
780+
customProvider.CustomClaimsAllowlist,
779781
a.oidcCache,
780782
)
781783
if err != nil {

internal/api/provider/custom_oauth.go

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ import (
1212

1313
// CustomOAuthProvider implements OAuthProvider for custom OAuth2 providers
1414
type CustomOAuthProvider struct {
15-
config *oauth2.Config
16-
userinfoURL string
17-
pkceEnabled bool
18-
acceptableClientIDs []string
19-
attributeMapping map[string]interface{}
20-
authorizationParams map[string]interface{}
15+
config *oauth2.Config
16+
userinfoURL string
17+
pkceEnabled bool
18+
acceptableClientIDs []string
19+
attributeMapping map[string]interface{}
20+
authorizationParams map[string]interface{}
21+
customClaimsAllowlist []string
2122
}
2223

2324
// NewCustomOAuthProvider creates a new custom OAuth provider
@@ -27,6 +28,7 @@ func NewCustomOAuthProvider(
2728
pkceEnabled bool,
2829
acceptableClientIDs []string,
2930
attributeMapping, authorizationParams map[string]interface{},
31+
customClaimsAllowlist []string,
3032
) *CustomOAuthProvider {
3133
config := &oauth2.Config{
3234
ClientID: clientID,
@@ -40,12 +42,13 @@ func NewCustomOAuthProvider(
4042
}
4143

4244
return &CustomOAuthProvider{
43-
config: config,
44-
userinfoURL: userinfoURL,
45-
pkceEnabled: pkceEnabled,
46-
acceptableClientIDs: acceptableClientIDs,
47-
attributeMapping: attributeMapping,
48-
authorizationParams: authorizationParams,
45+
config: config,
46+
userinfoURL: userinfoURL,
47+
pkceEnabled: pkceEnabled,
48+
acceptableClientIDs: acceptableClientIDs,
49+
attributeMapping: attributeMapping,
50+
authorizationParams: authorizationParams,
51+
customClaimsAllowlist: customClaimsAllowlist,
4952
}
5053
}
5154

@@ -68,11 +71,14 @@ func (p *CustomOAuthProvider) GetOAuthToken(ctx context.Context, code string, op
6871

6972
// GetUserData fetches user data from the provider's userinfo endpoint
7073
func (p *CustomOAuthProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
71-
var claims Claims
72-
if err := makeRequest(ctx, tok, p.config, p.userinfoURL, &claims); err != nil {
74+
claims, raw, err := fetchUserinfoClaims(ctx, tok, p.config, p.userinfoURL)
75+
if err != nil {
7376
return nil, err
7477
}
7578

79+
// Capture allowlisted custom claims before attribute mapping
80+
captureAllowedClaims(raw, p.customClaimsAllowlist, &claims)
81+
7682
// Apply attribute mapping if configured
7783
if len(p.attributeMapping) > 0 {
7884
claims = applyAttributeMapping(claims, p.attributeMapping)
@@ -101,13 +107,14 @@ func (p *CustomOAuthProvider) RequiresPKCE() bool {
101107

102108
// CustomOIDCProvider implements OAuthProvider for custom OIDC providers
103109
type CustomOIDCProvider struct {
104-
config *oauth2.Config
105-
oidcProvider *oidc.Provider
106-
userinfoEndpoint string
107-
pkceEnabled bool
108-
acceptableClientIDs []string
109-
attributeMapping map[string]interface{}
110-
authorizationParams map[string]interface{}
110+
config *oauth2.Config
111+
oidcProvider *oidc.Provider
112+
userinfoEndpoint string
113+
pkceEnabled bool
114+
acceptableClientIDs []string
115+
attributeMapping map[string]interface{}
116+
authorizationParams map[string]interface{}
117+
customClaimsAllowlist []string
111118
}
112119

113120
// NewCustomOIDCProvider creates a new custom OIDC provider.
@@ -123,6 +130,7 @@ func NewCustomOIDCProvider(
123130
pkceEnabled bool,
124131
acceptableClientIDs []string,
125132
attributeMapping, authorizationParams map[string]interface{},
133+
customClaimsAllowlist []string,
126134
cache *OIDCProviderCache,
127135
) (*CustomOIDCProvider, error) {
128136
// Ensure 'openid' scope is always present for OIDC
@@ -155,13 +163,14 @@ func NewCustomOIDCProvider(
155163
}
156164

157165
return &CustomOIDCProvider{
158-
config: config,
159-
oidcProvider: oidcProvider,
160-
userinfoEndpoint: userinfoEndpoint,
161-
pkceEnabled: pkceEnabled,
162-
acceptableClientIDs: acceptableClientIDs,
163-
attributeMapping: attributeMapping,
164-
authorizationParams: authorizationParams,
166+
config: config,
167+
oidcProvider: oidcProvider,
168+
userinfoEndpoint: userinfoEndpoint,
169+
pkceEnabled: pkceEnabled,
170+
acceptableClientIDs: acceptableClientIDs,
171+
attributeMapping: attributeMapping,
172+
authorizationParams: authorizationParams,
173+
customClaimsAllowlist: customClaimsAllowlist,
165174
}, nil
166175
}
167176

@@ -201,6 +210,17 @@ func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token)
201210
return nil, err
202211
}
203212

213+
// Capture allowlisted custom claims from the raw ID token before
214+
// attribute mapping. Because we only copy explicitly listed keys, there
215+
// is no risk of re-adding keys a parser intentionally stripped (e.g. Azure).
216+
if len(p.customClaimsAllowlist) > 0 && userData.Metadata != nil {
217+
var raw map[string]interface{}
218+
if err := idTokenObj.Claims(&raw); err != nil {
219+
return nil, fmt.Errorf("failed to read ID token claims: %w", err)
220+
}
221+
captureAllowedClaims(raw, p.customClaimsAllowlist, userData.Metadata)
222+
}
223+
204224
// Apply attribute mapping to the metadata from ID token
205225
if len(p.attributeMapping) > 0 && userData.Metadata != nil {
206226
*userData.Metadata = applyAttributeMapping(*userData.Metadata, p.attributeMapping)
@@ -211,11 +231,14 @@ func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token)
211231

212232
// No ID token, use userinfo endpoint
213233
if p.userinfoEndpoint != "" {
214-
var claims Claims
215-
if err := makeRequest(ctx, tok, p.config, p.userinfoEndpoint, &claims); err != nil {
234+
claims, raw, err := fetchUserinfoClaims(ctx, tok, p.config, p.userinfoEndpoint)
235+
if err != nil {
216236
return nil, err
217237
}
218238

239+
// Capture allowlisted custom claims before attribute mapping
240+
captureAllowedClaims(raw, p.customClaimsAllowlist, &claims)
241+
219242
// Apply attribute mapping
220243
if len(p.attributeMapping) > 0 {
221244
claims = applyAttributeMapping(claims, p.attributeMapping)
@@ -268,6 +291,44 @@ func (p *CustomOIDCProvider) validateAudience(audiences []string) error {
268291
return fmt.Errorf("token audience %v does not match any acceptable client ID", audiences)
269292
}
270293

294+
// fetchUserinfoClaims fetches the userinfo response once and returns both the
295+
// typed Claims and the raw claim map. The raw map is needed so that arbitrary
296+
// allowlisted keys (which have no typed field) can be copied verbatim.
297+
func fetchUserinfoClaims(ctx context.Context, tok *oauth2.Token, config *oauth2.Config, url string) (Claims, map[string]interface{}, error) {
298+
var raw map[string]interface{}
299+
if err := makeRequest(ctx, tok, config, url, &raw); err != nil {
300+
return Claims{}, nil, err
301+
}
302+
303+
var claims Claims
304+
b, err := json.Marshal(raw)
305+
if err != nil {
306+
return Claims{}, nil, err
307+
}
308+
if err := json.Unmarshal(b, &claims); err != nil {
309+
return Claims{}, nil, err
310+
}
311+
312+
return claims, raw, nil
313+
}
314+
315+
// captureAllowedClaims copies each allowlisted key present in raw into
316+
// c.CustomClaims verbatim. An empty allowlist captures nothing (D1), and keys
317+
// absent from raw are silently skipped (no nil entry is created). Because only
318+
// explicitly listed keys are copied, protocol/registered claims never leak.
319+
func captureAllowedClaims(raw map[string]interface{}, allowlist []string, c *Claims) {
320+
for _, key := range allowlist {
321+
value, ok := raw[key]
322+
if !ok {
323+
continue
324+
}
325+
if c.CustomClaims == nil {
326+
c.CustomClaims = make(map[string]interface{})
327+
}
328+
c.CustomClaims[key] = value
329+
}
330+
}
331+
271332
// applyAttributeMapping applies custom attribute mapping to claims
272333
func applyAttributeMapping(claims Claims, mapping map[string]interface{}) Claims {
273334
// Create a map representation of claims for easier manipulation

0 commit comments

Comments
 (0)