Skip to content

Commit f996602

Browse files
sylrclaude
andcommitted
feat(auth): support multiple trusted issuers (#555)
Port of 6b44d66 from release/v2.2 to v3.6, adapted to the v3 oidc package APIs. Allow configuring multiple OIDC issuers so that tokens from different identity providers (e.g. during a domain migration) are all accepted. - Add --auth-issuers flag (comma-separated) alongside existing --auth-issuer - Build one KeySet per issuer via OIDC discovery - Pre-parse tokens to route to the correct issuer's key set - Fail startup when auth is enabled but no issuers are configured Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2c91705 commit f996602

7 files changed

Lines changed: 146 additions & 99 deletions

File tree

pkg/authn/jwt/auth.go

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,28 @@ import (
1010
)
1111

1212
type JWTAuth struct {
13-
issuer string
13+
keySets map[string]oidc.KeySet // issuer -> keySet
1414
checkScopes bool
1515
service string
16-
keySet oidc.KeySet
1716
additionalChecks []AdditionalCheck
1817
}
1918

2019
func NewJWTAuth(
21-
keySet oidc.KeySet,
22-
issuer string,
20+
keySets map[string]oidc.KeySet,
2321
service string,
2422
checkScopes bool,
2523
additionalChecks []AdditionalCheck,
2624
) *JWTAuth {
2725
return &JWTAuth{
28-
issuer: issuer,
26+
keySets: keySets,
2927
checkScopes: checkScopes,
3028
service: service,
31-
keySet: keySet,
3229
additionalChecks: additionalChecks,
3330
}
3431
}
3532

3633
func (ja *JWTAuth) authenticate(r *http.Request) (ControlPlaneAgent, error) {
37-
claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet)
34+
claims, err := ClaimsFromRequest(r, ja.keySets)
3835
if err != nil {
3936
return nil, err
4037
}
@@ -77,44 +74,42 @@ var (
7774
ErrMalformedHeader = errors.New("malformed authorization header")
7875
)
7976

80-
func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySet) (*oidc.AccessTokenClaims, error) {
81-
claims := &oidc.AccessTokenClaims{}
82-
if err := claimsFromRequest(r, claims, keySet); err != nil {
83-
return claims, err
84-
}
85-
86-
if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil {
87-
return claims, err
88-
}
89-
90-
if err := oidc.CheckExpiration(claims, 0); err != nil {
91-
return claims, err
92-
}
93-
94-
return claims, nil
95-
}
96-
97-
func claimsFromRequest[CLAIMS any](r *http.Request, claims CLAIMS, keySet oidc.KeySet) error {
77+
func ClaimsFromRequest(r *http.Request, keySets map[string]oidc.KeySet) (*oidc.AccessTokenClaims, error) {
9878
authHeader := r.Header.Get("authorization")
9979
if authHeader == "" {
100-
return ErrNoAuthorizationHeader
80+
return nil, ErrNoAuthorizationHeader
10181
}
10282

10383
if !strings.HasPrefix(authHeader, "bearer") &&
10484
!strings.HasPrefix(authHeader, "Bearer") {
105-
return ErrMalformedHeader
85+
return nil, ErrMalformedHeader
10686
}
10787

10888
token := authHeader[6:]
10989
token = strings.TrimSpace(token)
11090

91+
claims := &oidc.AccessTokenClaims{}
11192
decrypted, err := oidc.DecryptToken(token)
11293
if err != nil {
113-
return err
94+
return nil, err
11495
}
115-
payload, err := oidc.ParseToken(decrypted, &claims)
96+
payload, err := oidc.ParseToken(decrypted, claims)
11697
if err != nil {
117-
return err
98+
return nil, err
99+
}
100+
101+
keySet, ok := keySets[claims.Issuer]
102+
if !ok {
103+
issuers := make([]string, 0, len(keySets))
104+
for iss := range keySets {
105+
issuers = append(issuers, iss)
106+
}
107+
return claims, fmt.Errorf(
108+
"%w: got: %s, trusted: %v",
109+
oidc.ErrIssuerInvalid,
110+
claims.Issuer,
111+
issuers,
112+
)
118113
}
119114

120115
if _, err = oidc.CheckSignature(
@@ -124,8 +119,12 @@ func claimsFromRequest[CLAIMS any](r *http.Request, claims CLAIMS, keySet oidc.K
124119
[]string{}, // Default to RS256
125120
keySet,
126121
); err != nil {
127-
return err
122+
return claims, err
123+
}
124+
125+
if err := oidc.CheckExpiration(claims, 0); err != nil {
126+
return claims, err
128127
}
129128

130-
return nil
129+
return claims, nil
131130
}

pkg/authn/jwt/auth_test.go

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
151151
t.Parallel()
152152
keySet, privateKey, issuer := setupTestKeySet(t)
153153

154-
auth := NewJWTAuth(keySet, issuer, "test-service", false, []AdditionalCheck{})
154+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, []AdditionalCheck{})
155155

156156
// Create access token
157157
token := createAccessToken(t, privateKey, issuer, "", []string{}, "test-user")
@@ -175,11 +175,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
175175
}{
176176
{
177177
name: "JWTAuth",
178-
auth: NewJWTAuth(keySet, issuer, "test-service", false, []AdditionalCheck{}),
178+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, []AdditionalCheck{}),
179179
},
180180
{
181181
name: "JWTAuth with additional checks",
182-
auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks),
182+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, autoPassingAdditionalChecks),
183183
},
184184
}
185185

@@ -206,11 +206,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
206206
}{
207207
{
208208
name: "JWTAuth",
209-
auth: NewJWTAuth(keySet, issuer, "test-service", false, nil),
209+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, nil),
210210
},
211211
{
212212
name: "JWTAuth with additional checks",
213-
auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks),
213+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, autoPassingAdditionalChecks),
214214
},
215215
}
216216

@@ -237,11 +237,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
237237
}{
238238
{
239239
name: "JWTAuth",
240-
auth: NewJWTAuth(keySet, issuer, "test-service", false, nil),
240+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, nil),
241241
},
242242
{
243243
name: "JWTAuth with additional checks",
244-
auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks),
244+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, autoPassingAdditionalChecks),
245245
},
246246
}
247247

@@ -267,11 +267,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
267267
}{
268268
{
269269
name: "JWTAuth",
270-
auth: NewJWTAuth(keySet, issuer, "test-service", false, nil),
270+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, nil),
271271
},
272272
{
273273
name: "JWTAuth with additional checks",
274-
auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks),
274+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, autoPassingAdditionalChecks),
275275
},
276276
}
277277

@@ -329,11 +329,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
329329
}{
330330
{
331331
name: "JWTAuth",
332-
auth: NewJWTAuth(keySet, issuer, "test-service", true, nil),
332+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, nil),
333333
},
334334
{
335335
name: "JWTAuth with additional checks",
336-
auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks),
336+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, autoPassingAdditionalChecks),
337337
},
338338
}
339339

@@ -364,11 +364,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
364364
}{
365365
{
366366
name: "JWTAuth",
367-
auth: NewJWTAuth(keySet, issuer, "test-service", true, nil),
367+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, nil),
368368
},
369369
{
370370
name: "JWTAuth with additional checks",
371-
auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks),
371+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, autoPassingAdditionalChecks),
372372
},
373373
}
374374

@@ -398,11 +398,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
398398
}{
399399
{
400400
name: "JWTAuth",
401-
auth: NewJWTAuth(keySet, issuer, "test-service", true, nil),
401+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, nil),
402402
},
403403
{
404404
name: "JWTAuth with additional checks",
405-
auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks),
405+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, autoPassingAdditionalChecks),
406406
},
407407
}
408408

@@ -433,11 +433,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
433433
}{
434434
{
435435
name: "JWTAuth",
436-
auth: NewJWTAuth(keySet, issuer, "test-service", true, nil),
436+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, nil),
437437
},
438438
{
439439
name: "JWTAuth with additional checks",
440-
auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks),
440+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", true, autoPassingAdditionalChecks),
441441
},
442442
}
443443

@@ -468,11 +468,11 @@ func TestJWTAuth_Authenticate(t *testing.T) {
468468
}{
469469
{
470470
name: "JWTAuth",
471-
auth: NewJWTAuth(keySet, issuer, "test-service", false, nil),
471+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, nil),
472472
},
473473
{
474474
name: "JWTAuth with additional checks",
475-
auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks),
475+
auth: NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, autoPassingAdditionalChecks),
476476
},
477477
}
478478

@@ -516,7 +516,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
516516
},
517517
}
518518

519-
auth := NewJWTAuth(keySet, issuer, "test-service", false, autoFailingAdditionalChecks)
519+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, autoFailingAdditionalChecks)
520520

521521
// Create access token
522522
token := createAccessToken(t, privateKey, issuer, "", []string{}, "test-user")
@@ -543,7 +543,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
543543
CheckOrganizationIDClaim(provider),
544544
}
545545

546-
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
546+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, additionalChecks)
547547

548548
// Create access token
549549
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", expectedOrgID)
@@ -566,7 +566,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
566566
additionalChecks := []AdditionalCheck{
567567
CheckOrganizationIDClaim(provider),
568568
}
569-
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
569+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, additionalChecks)
570570

571571
// Create access token
572572
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")
@@ -590,7 +590,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
590590
additionalChecks := []AdditionalCheck{
591591
CheckOrganizationIDClaim(provider),
592592
}
593-
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
593+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, additionalChecks)
594594

595595
// Create access token
596596
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "someotherorgid")
@@ -615,7 +615,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
615615
additionalChecks := []AdditionalCheck{
616616
CheckOrganizationIDClaim(provider),
617617
}
618-
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
618+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, additionalChecks)
619619

620620
// Create access token
621621
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")
@@ -639,7 +639,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
639639
additionalChecks := []AdditionalCheck{
640640
CheckAudienceClaim(expectedAudience),
641641
}
642-
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
642+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, additionalChecks)
643643

644644
// Create access token
645645
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")
@@ -663,7 +663,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
663663
additionalChecks := []AdditionalCheck{
664664
CheckAudienceClaim(expectedAudience),
665665
}
666-
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
666+
auth := NewJWTAuth(map[string]oidc.KeySet{issuer: keySet}, "test-service", false, additionalChecks)
667667

668668
// Create access token
669669
tokenAudience := expectedAudience

pkg/authn/jwt/flags.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@ import (
77
const (
88
AuthEnabledFlag = "auth-enabled"
99
AuthIssuerFlag = "auth-issuer"
10+
AuthIssuersFlag = "auth-issuers"
1011
AuthReadKeySetMaxRetriesFlag = "auth-read-key-set-max-retries"
1112
AuthCheckScopesFlag = "auth-check-scopes"
1213
AuthServiceFlag = "auth-service"
1314
)
1415

1516
func AddFlags(flags *flag.FlagSet) {
1617
flags.Bool(AuthEnabledFlag, false, "Enable auth")
17-
flags.String(AuthIssuerFlag, "", "Issuer")
18+
flags.String(AuthIssuerFlag, "", "Issuer (single issuer, for backward compatibility)")
19+
flags.StringSlice(AuthIssuersFlag, nil, "Trusted issuers (comma-separated, e.g. --auth-issuers=https://issuer1,https://issuer2)")
1820
flags.Int(AuthReadKeySetMaxRetriesFlag, 10, "ReadKeySetMaxRetries")
1921
flags.Bool(AuthCheckScopesFlag, false, "CheckScopes")
2022
flags.String(AuthServiceFlag, "", "Service")
@@ -23,13 +25,28 @@ func AddFlags(flags *flag.FlagSet) {
2325
func ConfigFromFlags(flags *flag.FlagSet) Config {
2426
authEnabled, _ := flags.GetBool(AuthEnabledFlag)
2527
authIssuer, _ := flags.GetString(AuthIssuerFlag)
28+
authIssuers, _ := flags.GetStringSlice(AuthIssuersFlag)
2629
authReadKeySetMaxRetries, _ := flags.GetInt(AuthReadKeySetMaxRetriesFlag)
2730
authCheckScopes, _ := flags.GetBool(AuthCheckScopesFlag)
2831
authService, _ := flags.GetString(AuthServiceFlag)
2932

33+
// Merge --auth-issuer into --auth-issuers for backward compatibility
34+
if authIssuer != "" {
35+
found := false
36+
for _, iss := range authIssuers {
37+
if iss == authIssuer {
38+
found = true
39+
break
40+
}
41+
}
42+
if !found {
43+
authIssuers = append(authIssuers, authIssuer)
44+
}
45+
}
46+
3047
return Config{
3148
Enabled: authEnabled,
32-
Issuer: authIssuer,
49+
Issuers: authIssuers,
3350
ReadKeySetMaxRetries: authReadKeySetMaxRetries,
3451
CheckScopes: authCheckScopes,
3552
Service: authService,

0 commit comments

Comments
 (0)