Skip to content

Commit 79fe6a4

Browse files
authored
feat(auth): Inject auth claims into request context and authenticate audience (#540)
* feat(auth): Inject auth claims into request context * feat(auth): Add audience verification as additional check * Split between Stack middleware and Control Plane middleware to separate concerns * Add AnnotatedModule function * Use exact string match for audience check * Add debug log for new auth middleware
1 parent a78dedb commit 79fe6a4

14 files changed

Lines changed: 698 additions & 71 deletions

auth/additional_checks.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,18 @@ func CheckOrganizationIDClaim(fn OrganizationIDProvider) AdditionalCheck {
4242
return nil
4343
}
4444
}
45+
46+
func CheckAudienceClaim(expectedAudienceUrl string) AdditionalCheck {
47+
return func(_ *http.Request, claims *oidc.AccessTokenClaims) error {
48+
if claims == nil {
49+
return fmt.Errorf("claims cannot be nil")
50+
}
51+
52+
for _, aud := range claims.GetAudience() {
53+
if aud == expectedAudienceUrl {
54+
return nil
55+
}
56+
}
57+
return oidc.ErrAudience
58+
}
59+
}

auth/additional_checks_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package auth_test
2+
3+
import (
4+
"errors"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
9+
"github.com/formancehq/go-libs/v3/auth"
10+
"github.com/formancehq/go-libs/v3/oidc"
11+
)
12+
13+
func TestCheckAudienceClaim(t *testing.T) {
14+
tests := map[string]struct {
15+
expectedAudienceStr string
16+
claims *oidc.AccessTokenClaims
17+
expectedError error
18+
}{
19+
"NilClaims": {
20+
claims: nil,
21+
expectedError: errors.New("claims cannot be nil"),
22+
},
23+
"MatchingAudience with url scheme": {
24+
expectedAudienceStr: "http://example.com",
25+
claims: &oidc.AccessTokenClaims{
26+
TokenClaims: oidc.TokenClaims{
27+
Audience: []string{"http://example.com"},
28+
},
29+
},
30+
expectedError: nil,
31+
},
32+
"NonMatchingAudience with url scheme": {
33+
expectedAudienceStr: "http://example.com",
34+
claims: &oidc.AccessTokenClaims{
35+
TokenClaims: oidc.TokenClaims{
36+
Audience: []string{"http://another.com"},
37+
},
38+
},
39+
expectedError: oidc.ErrAudience,
40+
},
41+
"Multiple audiences in claim; one matches": {
42+
expectedAudienceStr: "example.com",
43+
claims: &oidc.AccessTokenClaims{
44+
TokenClaims: oidc.TokenClaims{
45+
Audience: []string{"otherdomain.com", "example.com", "123.com"},
46+
},
47+
},
48+
expectedError: nil,
49+
},
50+
"Multiple audiences in claim but none match": {
51+
expectedAudienceStr: "http://example.com",
52+
claims: &oidc.AccessTokenClaims{
53+
TokenClaims: oidc.TokenClaims{
54+
Audience: []string{"another.com", "ple.com", "subdomain.example.com"},
55+
},
56+
},
57+
expectedError: oidc.ErrAudience,
58+
},
59+
}
60+
61+
for name, tt := range tests {
62+
t.Run(name, func(t *testing.T) {
63+
check := auth.CheckAudienceClaim(tt.expectedAudienceStr)
64+
err := check(nil, tt.claims)
65+
assert.Equal(t, tt.expectedError, err)
66+
})
67+
}
68+
}

auth/auth.go

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package auth
22

33
import (
44
"errors"
5+
"fmt"
56
"net/http"
67
"strings"
78

@@ -32,24 +33,43 @@ func NewJWTAuth(
3233
}
3334
}
3435

35-
// Authenticate validates the JWT in the request and returns the user, if valid.
36-
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {
36+
func (ja *JWTAuth) authenticate(r *http.Request) (ControlPlaneAgent, error) {
3737
claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet)
3838
if err != nil {
39-
return false, err
39+
return nil, err
4040
}
4141

42+
// DefaultControlPlaneAgent provides access to claims that are expected to be present when authenticating via the Control Plane
43+
// in the case of another issuer (eg. Stack authentication) some of these claims may not be present
44+
agt := NewDefaultControlPlaneAgent(*claims)
4245
for _, check := range ja.additionalChecks {
4346
err := check(r, claims)
4447
if err != nil {
45-
return false, err
48+
return agt, err
4649
}
4750
}
4851

4952
if !ja.checkScopes {
50-
return true, nil
53+
return agt, nil
54+
}
55+
valid, err := checkScopes(ja.service, r.Method, claims.Scopes)
56+
if err != nil || !valid {
57+
return agt, fmt.Errorf("scopes not valid: %w", err)
58+
}
59+
return agt, nil
60+
}
61+
62+
func (ja *JWTAuth) AuthenticateOnControlPlane(r *http.Request) (ControlPlaneAgent, error) {
63+
return ja.authenticate(r)
64+
}
65+
66+
// Authenticate validates the JWT in the request and returns the user, if valid.
67+
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {
68+
_, err := ja.authenticate(r)
69+
if err != nil {
70+
return false, err
5171
}
52-
return checkScopes(ja.service, r.Method, claims.Scopes)
72+
return true, nil
5373
}
5474

5575
var (

auth/auth_test.go

Lines changed: 73 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,19 @@ func setupTestKeySet(t *testing.T) (oidc.KeySet, *rsa.PrivateKey, string) {
4040
return keySet, privateKey, issuer
4141
}
4242

43-
func createAccessToken(t *testing.T, privateKey *rsa.PrivateKey, issuer string, scopes []string, subject string) string {
43+
func createAccessToken(t *testing.T, privateKey *rsa.PrivateKey, issuer string, audience string, scopes []string, subject string) string {
4444
now := stdtime.Now().UTC()
4545
expirationTime := libtime.New(now.Add(1 * stdtime.Hour))
4646

47+
audiences := make([]string, 0, 1)
48+
if audience != "" {
49+
audiences = append(audiences, audience)
50+
}
51+
4752
accessTokenClaims := oidc.NewAccessTokenClaims(
4853
issuer,
4954
subject,
50-
[]string{"test-client"},
55+
audiences,
5156
expirationTime,
5257
"test-jti",
5358
"test-client",
@@ -82,17 +87,23 @@ func createAccessTokenWithOrgClaims(
8287
t *testing.T,
8388
privateKey *rsa.PrivateKey,
8489
issuer string,
90+
audience string,
8591
scopes []string,
8692
subject string,
8793
organizationID string,
8894
) string {
8995
now := stdtime.Now().UTC()
9096
expirationTime := libtime.New(now.Add(1 * stdtime.Hour))
9197

98+
audiences := make([]string, 0, 1)
99+
if audience != "" {
100+
audiences = append(audiences, audience)
101+
}
102+
92103
accessTokenClaims := oidc.NewOrganizationAwareAccessTokenClaims(
93104
issuer,
94105
subject,
95-
[]string{"test-client"},
106+
audiences,
96107
expirationTime,
97108
"test-jti",
98109
"test-client",
@@ -143,7 +154,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
143154
auth := NewJWTAuth(keySet, issuer, "test-service", false, []AdditionalCheck{})
144155

145156
// Create access token
146-
token := createAccessToken(t, privateKey, issuer, []string{}, "test-user")
157+
token := createAccessToken(t, privateKey, issuer, "", []string{}, "test-user")
147158

148159
// Create request with valid token
149160
req := httptest.NewRequest("GET", "/test", nil)
@@ -330,7 +341,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
330341
t.Run(tt.name, func(t *testing.T) {
331342

332343
// Create access token with read scope
333-
token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user")
344+
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:read"}, "test-user")
334345

335346
req := httptest.NewRequest("GET", "/test", nil)
336347
req.Header.Set("Authorization", "Bearer "+token)
@@ -364,7 +375,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
364375
for _, tt := range tests {
365376
t.Run(tt.name, func(t *testing.T) {
366377
// Create access token with write scope
367-
token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user")
378+
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:write"}, "test-user")
368379

369380
req := httptest.NewRequest("POST", "/test", nil)
370381
req.Header.Set("Authorization", "Bearer "+token)
@@ -398,7 +409,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
398409
for _, tt := range tests {
399410
t.Run(tt.name, func(t *testing.T) {
400411
// Create access token with only read scope (not enough for POST)
401-
token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user")
412+
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:read"}, "test-user")
402413

403414
req := httptest.NewRequest("POST", "/test", nil)
404415
req.Header.Set("Authorization", "Bearer "+token)
@@ -433,7 +444,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
433444
for _, tt := range tests {
434445
t.Run(tt.name, func(t *testing.T) {
435446
// Create access token with write scope
436-
token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user")
447+
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:write"}, "test-user")
437448

438449
req := httptest.NewRequest("GET", "/test", nil)
439450
req.Header.Set("Authorization", "Bearer "+token)
@@ -468,7 +479,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
468479
for _, tt := range tests {
469480
t.Run(tt.name, func(t *testing.T) {
470481
// Create access token
471-
token := createAccessToken(t, privateKey, unexpectedIssuer, []string{}, "test-user")
482+
token := createAccessToken(t, privateKey, unexpectedIssuer, "", []string{}, "test-user")
472483

473484
// Create request with valid token
474485
req := httptest.NewRequest("GET", "/test", nil)
@@ -508,7 +519,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
508519
auth := NewJWTAuth(keySet, issuer, "test-service", false, autoFailingAdditionalChecks)
509520

510521
// Create access token
511-
token := createAccessToken(t, privateKey, issuer, []string{}, "test-user")
522+
token := createAccessToken(t, privateKey, issuer, "", []string{}, "test-user")
512523

513524
// Create request with valid token
514525
req := httptest.NewRequest("GET", "/test", nil)
@@ -535,7 +546,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
535546
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
536547

537548
// Create access token
538-
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", expectedOrgID)
549+
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", expectedOrgID)
539550

540551
// Create request with valid token
541552
req := httptest.NewRequest("GET", "/test", nil)
@@ -558,7 +569,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
558569
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
559570

560571
// Create access token
561-
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "")
572+
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")
562573

563574
// Create request with valid token
564575
req := httptest.NewRequest("GET", "/test", nil)
@@ -582,7 +593,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
582593
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
583594

584595
// Create access token
585-
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "someotherorgid")
596+
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "someotherorgid")
586597

587598
// Create request with valid token
588599
req := httptest.NewRequest("GET", "/test", nil)
@@ -607,7 +618,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
607618
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
608619

609620
// Create access token
610-
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "")
621+
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")
611622

612623
// Create request with valid token
613624
req := httptest.NewRequest("GET", "/test", nil)
@@ -619,4 +630,52 @@ func TestJWTAuth_Authenticate(t *testing.T) {
619630
assert.ErrorIs(t, err, oidc.ErrOrgIDNotPresent)
620631
assert.False(t, authenticated)
621632
})
633+
634+
t.Run("CheckAudienceClaim audience mismatches", func(t *testing.T) {
635+
t.Parallel()
636+
keySet, privateKey, issuer := setupTestKeySet(t)
637+
expectedAudience := "http://expected.mydomain.com"
638+
639+
additionalChecks := []AdditionalCheck{
640+
CheckAudienceClaim(expectedAudience),
641+
}
642+
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
643+
644+
// Create access token
645+
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")
646+
647+
// Create request with valid token
648+
req := httptest.NewRequest("GET", "/test", nil)
649+
req.Header.Set("Authorization", "Bearer "+token)
650+
req = req.WithContext(logging.TestingContext())
651+
652+
authenticated, err := auth.Authenticate(nil, req)
653+
require.Error(t, err)
654+
assert.ErrorIs(t, err, oidc.ErrAudience)
655+
assert.False(t, authenticated)
656+
})
657+
658+
t.Run("CheckAudienceClaim audience matches", func(t *testing.T) {
659+
t.Parallel()
660+
keySet, privateKey, issuer := setupTestKeySet(t)
661+
expectedAudience := "http://expected.mydomain.com"
662+
663+
additionalChecks := []AdditionalCheck{
664+
CheckAudienceClaim(expectedAudience),
665+
}
666+
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)
667+
668+
// Create access token
669+
tokenAudience := expectedAudience
670+
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, tokenAudience, []string{}, "test-user", "")
671+
672+
// Create request with valid token
673+
req := httptest.NewRequest("GET", "/test", nil)
674+
req.Header.Set("Authorization", "Bearer "+token)
675+
req = req.WithContext(logging.TestingContext())
676+
677+
authenticated, err := auth.Authenticate(nil, req)
678+
require.NoError(t, err)
679+
assert.True(t, authenticated)
680+
})
622681
}

auth/authenticator.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package auth
2+
3+
import "net/http"
4+
5+
//go:generate mockgen -source authenticator.go -destination authenticator_generated.go -package auth . Authenticator
6+
type Authenticator interface {
7+
Authenticate(w http.ResponseWriter, r *http.Request) (bool, error)
8+
AuthenticateOnControlPlane(r *http.Request) (ControlPlaneAgent, error)
9+
}

auth/authenticator_generated.go

Lines changed: 17 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)