Skip to content

Commit bb46397

Browse files
committed
feat(custom-oauth): preserve non-standard IdP claims in identity_data
1 parent be317c1 commit bb46397

2 files changed

Lines changed: 219 additions & 4 deletions

File tree

internal/api/provider/custom_oauth.go

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,72 @@ import (
55
"encoding/json"
66
"errors"
77
"fmt"
8+
"reflect"
9+
"strings"
810

911
"github.com/coreos/go-oidc/v3/oidc"
1012
"golang.org/x/oauth2"
1113
)
1214

15+
// standardClaimKeys is the set of JSON keys handled by the typed Claims
16+
// struct (OIDC standard claims plus our typed extensions, and the
17+
// custom_claims sink itself). Derived from the struct's json tags so it
18+
// can't drift when Claims changes.
19+
var standardClaimKeys = jsonKeysOf(reflect.TypeOf(Claims{}))
20+
21+
// jsonKeysOf returns the set of JSON keys declared by a struct type's tags.
22+
func jsonKeysOf(t reflect.Type) map[string]bool {
23+
keys := map[string]bool{}
24+
for i := 0; i < t.NumField(); i++ {
25+
tag, _, _ := strings.Cut(t.Field(i).Tag.Get("json"), ",")
26+
if tag != "" && tag != "-" {
27+
keys[tag] = true
28+
}
29+
}
30+
return keys
31+
}
32+
33+
// captureCustomClaims merges any top-level keys in raw that aren't part of
34+
// the typed Claims struct into c.CustomClaims, so provider-specific claims
35+
// (groups, roles, org_id, …) survive into identity_data downstream. Entries
36+
// already present in c.CustomClaims (e.g. populated by the typed decode of
37+
// a literal "custom_claims" object on the IdP response) win over any
38+
// same-named top-level key.
39+
func captureCustomClaims(raw map[string]interface{}, c *Claims) {
40+
for k, v := range raw {
41+
if standardClaimKeys[k] {
42+
continue
43+
}
44+
if _, exists := c.CustomClaims[k]; exists {
45+
continue
46+
}
47+
if c.CustomClaims == nil {
48+
c.CustomClaims = map[string]interface{}{}
49+
}
50+
c.CustomClaims[k] = v
51+
}
52+
}
53+
54+
// customClaims is a Claims wrapper whose UnmarshalJSON also captures
55+
// non-standard top-level keys under CustomClaims. Scoped to custom OAuth/OIDC
56+
// providers — built-in providers (Google, Apple, …) decode into Claims
57+
// directly and keep their existing behaviour.
58+
type customClaims Claims
59+
60+
func (c *customClaims) UnmarshalJSON(data []byte) error {
61+
var standard Claims
62+
if err := json.Unmarshal(data, &standard); err != nil {
63+
return err
64+
}
65+
*c = customClaims(standard)
66+
67+
var raw map[string]interface{}
68+
if err := json.Unmarshal(data, &raw); err == nil {
69+
captureCustomClaims(raw, (*Claims)(c))
70+
}
71+
return nil
72+
}
73+
1374
// CustomOAuthProvider implements OAuthProvider for custom OAuth2 providers
1475
type CustomOAuthProvider struct {
1576
config *oauth2.Config
@@ -68,10 +129,11 @@ func (p *CustomOAuthProvider) GetOAuthToken(ctx context.Context, code string, op
68129

69130
// GetUserData fetches user data from the provider's userinfo endpoint
70131
func (p *CustomOAuthProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) {
71-
var claims Claims
72-
if err := makeRequest(ctx, tok, p.config, p.userinfoURL, &claims); err != nil {
132+
var decoded customClaims
133+
if err := makeRequest(ctx, tok, p.config, p.userinfoURL, &decoded); err != nil {
73134
return nil, err
74135
}
136+
claims := Claims(decoded)
75137

76138
// Apply attribute mapping if configured
77139
if len(p.attributeMapping) > 0 {
@@ -198,6 +260,16 @@ func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token)
198260
return nil, err
199261
}
200262

263+
// ParseIDToken decodes into the typed Claims struct and drops anything
264+
// not mapped. Re-walk the raw claims to capture provider-specific
265+
// custom claims (groups, roles, …) for identity_data.
266+
if userData.Metadata != nil {
267+
var raw map[string]interface{}
268+
if err := idTokenObj.Claims(&raw); err == nil {
269+
captureCustomClaims(raw, userData.Metadata)
270+
}
271+
}
272+
201273
// Apply attribute mapping to the metadata from ID token
202274
if len(p.attributeMapping) > 0 && userData.Metadata != nil {
203275
*userData.Metadata = applyAttributeMapping(*userData.Metadata, p.attributeMapping)
@@ -208,10 +280,11 @@ func (p *CustomOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token)
208280

209281
// No ID token, use userinfo endpoint
210282
if p.userinfoEndpoint != "" {
211-
var claims Claims
212-
if err := makeRequest(ctx, tok, p.config, p.userinfoEndpoint, &claims); err != nil {
283+
var decoded customClaims
284+
if err := makeRequest(ctx, tok, p.config, p.userinfoEndpoint, &decoded); err != nil {
213285
return nil, err
214286
}
287+
claims := Claims(decoded)
215288

216289
// Apply attribute mapping
217290
if len(p.attributeMapping) > 0 {

internal/api/provider/custom_oauth_test.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,148 @@ func TestCustomOAuthProvider_GetUserDataWithAttributeMapping(t *testing.T) {
170170
assert.True(t, userData.Emails[0].Verified) // Should be true from literal mapping
171171
}
172172

173+
func TestCustomOAuthProvider_GetUserDataPreservesCustomClaims(t *testing.T) {
174+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
175+
w.Header().Set("Content-Type", "application/json")
176+
json.NewEncoder(w).Encode(map[string]interface{}{
177+
"sub": "user-123",
178+
"email": "test@example.com",
179+
"email_verified": true,
180+
"name": "Test User",
181+
// Non-standard claims that previously got silently dropped.
182+
"groups": []string{"admins", "billing"},
183+
"org_id": "org_42",
184+
"tenant_id": "tenant-abc",
185+
})
186+
}))
187+
defer server.Close()
188+
189+
provider := NewCustomOAuthProvider(
190+
"client-id",
191+
"client-secret",
192+
"https://example.com/authorize",
193+
"https://example.com/token",
194+
server.URL,
195+
"https://myapp.com/callback",
196+
[]string{"openid", "profile", "email"},
197+
false,
198+
nil,
199+
nil,
200+
nil,
201+
)
202+
203+
token := &oauth2.Token{AccessToken: "test-access-token", TokenType: "Bearer"}
204+
userData, err := provider.GetUserData(context.Background(), token)
205+
require.NoError(t, err)
206+
require.NotNil(t, userData)
207+
require.NotNil(t, userData.Metadata)
208+
209+
// Standard fields still populated.
210+
assert.Equal(t, "user-123", userData.Metadata.Subject)
211+
assert.Equal(t, "test@example.com", userData.Metadata.Email)
212+
assert.Equal(t, "Test User", userData.Metadata.Name)
213+
214+
// Non-standard claims preserved under CustomClaims.
215+
require.NotNil(t, userData.Metadata.CustomClaims)
216+
assert.Equal(t, "org_42", userData.Metadata.CustomClaims["org_id"])
217+
assert.Equal(t, "tenant-abc", userData.Metadata.CustomClaims["tenant_id"])
218+
219+
groups, ok := userData.Metadata.CustomClaims["groups"].([]interface{})
220+
require.True(t, ok, "groups should round-trip as []interface{}")
221+
require.Len(t, groups, 2)
222+
assert.Equal(t, "admins", groups[0])
223+
assert.Equal(t, "billing", groups[1])
224+
225+
// Known fields must NOT also leak into CustomClaims.
226+
_, hasEmail := userData.Metadata.CustomClaims["email"]
227+
assert.False(t, hasEmail)
228+
_, hasSub := userData.Metadata.CustomClaims["sub"]
229+
assert.False(t, hasSub)
230+
}
231+
232+
func TestCustomClaimsUnmarshalCapturesCustomClaims(t *testing.T) {
233+
t.Run("standard claims fill typed fields and non-standard claims land in CustomClaims", func(t *testing.T) {
234+
body := []byte(`{
235+
"sub": "u-1",
236+
"email": "a@b.com",
237+
"email_verified": true,
238+
"groups": ["x"],
239+
"org_id": "o1"
240+
}`)
241+
var c customClaims
242+
require.NoError(t, json.Unmarshal(body, &c))
243+
244+
assert.Equal(t, "u-1", c.Subject)
245+
assert.Equal(t, "a@b.com", c.Email)
246+
assert.True(t, c.EmailVerified)
247+
require.NotNil(t, c.CustomClaims)
248+
assert.Equal(t, "o1", c.CustomClaims["org_id"])
249+
assert.Equal(t, []interface{}{"x"}, c.CustomClaims["groups"])
250+
251+
_, hasEmail := c.CustomClaims["email"]
252+
assert.False(t, hasEmail, "standard claims must not also leak into CustomClaims")
253+
})
254+
255+
t.Run("only standard claims means CustomClaims stays nil", func(t *testing.T) {
256+
body := []byte(`{"sub":"u","email":"a@b.com"}`)
257+
var c customClaims
258+
require.NoError(t, json.Unmarshal(body, &c))
259+
assert.Nil(t, c.CustomClaims)
260+
})
261+
262+
t.Run("provider that literally returns custom_claims is preserved flat (not re-nested)", func(t *testing.T) {
263+
body := []byte(`{"sub":"u-2","custom_claims":{"foo":"bar"}}`)
264+
var c customClaims
265+
require.NoError(t, json.Unmarshal(body, &c))
266+
assert.Equal(t, "u-2", c.Subject)
267+
require.NotNil(t, c.CustomClaims)
268+
assert.Equal(t, "bar", c.CustomClaims["foo"])
269+
_, nested := c.CustomClaims["custom_claims"]
270+
assert.False(t, nested, "custom_claims must not be re-nested under itself")
271+
})
272+
273+
t.Run("custom_claims object and other non-standard keys are merged at top level", func(t *testing.T) {
274+
body := []byte(`{
275+
"sub": "u-3",
276+
"custom_claims": {"foo": "bar"},
277+
"groups": ["admins"],
278+
"org_id": "o1"
279+
}`)
280+
var c customClaims
281+
require.NoError(t, json.Unmarshal(body, &c))
282+
assert.Equal(t, "u-3", c.Subject)
283+
require.NotNil(t, c.CustomClaims)
284+
assert.Equal(t, "bar", c.CustomClaims["foo"])
285+
assert.Equal(t, "o1", c.CustomClaims["org_id"])
286+
assert.Equal(t, []interface{}{"admins"}, c.CustomClaims["groups"])
287+
_, nested := c.CustomClaims["custom_claims"]
288+
assert.False(t, nested, "custom_claims must not be re-nested under itself")
289+
})
290+
291+
t.Run("entries inside custom_claims win over same-named top-level keys", func(t *testing.T) {
292+
// IdP returns "groups" both inside custom_claims and at the top
293+
// level. The typed decode places the inner value into CustomClaims
294+
// first; the outer one must not silently overwrite it.
295+
body := []byte(`{
296+
"sub": "u-4",
297+
"custom_claims": {"groups": ["from-inner"]},
298+
"groups": ["from-outer"]
299+
}`)
300+
var c customClaims
301+
require.NoError(t, json.Unmarshal(body, &c))
302+
require.NotNil(t, c.CustomClaims)
303+
assert.Equal(t, []interface{}{"from-inner"}, c.CustomClaims["groups"])
304+
})
305+
306+
t.Run("plain Claims still drops non-standard claims (proves scoping)", func(t *testing.T) {
307+
body := []byte(`{"sub":"u","groups":["x"]}`)
308+
var c Claims
309+
require.NoError(t, json.Unmarshal(body, &c))
310+
assert.Equal(t, "u", c.Subject)
311+
assert.Nil(t, c.CustomClaims, "non-custom providers must keep existing drop behaviour")
312+
})
313+
}
314+
173315
func TestApplyAttributeMapping(t *testing.T) {
174316
tests := []struct {
175317
name string

0 commit comments

Comments
 (0)