Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions auth/additional_checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,18 @@ func CheckOrganizationIDClaim(fn OrganizationIDProvider) AdditionalCheck {
return nil
}
}

func CheckAudienceClaim(expectedAudienceUrl string) AdditionalCheck {
return func(_ *http.Request, claims *oidc.AccessTokenClaims) error {
if claims == nil {
return fmt.Errorf("claims cannot be nil")
}

for _, aud := range claims.GetAudience() {
if aud == expectedAudienceUrl {
return nil
}
}
return oidc.ErrAudience
}
}
68 changes: 68 additions & 0 deletions auth/additional_checks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package auth_test

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"

"github.com/formancehq/go-libs/v3/auth"
"github.com/formancehq/go-libs/v3/oidc"
)

func TestCheckAudienceClaim(t *testing.T) {
tests := map[string]struct {
expectedAudienceStr string
claims *oidc.AccessTokenClaims
expectedError error
}{
"NilClaims": {
claims: nil,
expectedError: errors.New("claims cannot be nil"),
},
"MatchingAudience with url scheme": {
expectedAudienceStr: "http://example.com",
claims: &oidc.AccessTokenClaims{
TokenClaims: oidc.TokenClaims{
Audience: []string{"http://example.com"},
},
},
expectedError: nil,
},
"NonMatchingAudience with url scheme": {
expectedAudienceStr: "http://example.com",
claims: &oidc.AccessTokenClaims{
TokenClaims: oidc.TokenClaims{
Audience: []string{"http://another.com"},
},
},
expectedError: oidc.ErrAudience,
},
"Multiple audiences in claim; one matches": {
expectedAudienceStr: "example.com",
claims: &oidc.AccessTokenClaims{
TokenClaims: oidc.TokenClaims{
Audience: []string{"otherdomain.com", "example.com", "123.com"},
},
},
expectedError: nil,
},
"Multiple audiences in claim but none match": {
expectedAudienceStr: "http://example.com",
claims: &oidc.AccessTokenClaims{
TokenClaims: oidc.TokenClaims{
Audience: []string{"another.com", "ple.com", "subdomain.example.com"},
},
},
expectedError: oidc.ErrAudience,
},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
check := auth.CheckAudienceClaim(tt.expectedAudienceStr)
err := check(nil, tt.claims)
assert.Equal(t, tt.expectedError, err)
})
}
}
32 changes: 26 additions & 6 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package auth

import (
"errors"
"fmt"
"net/http"
"strings"

Expand Down Expand Up @@ -32,24 +33,43 @@ func NewJWTAuth(
}
}

// Authenticate validates the JWT in the request and returns the user, if valid.
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {
func (ja *JWTAuth) authenticate(r *http.Request) (ControlPlaneAgent, error) {
claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet)
if err != nil {
return false, err
return nil, err
}
Comment thread
laouji marked this conversation as resolved.

// DefaultControlPlaneAgent provides access to claims that are expected to be present when authenticating via the Control Plane
// in the case of another issuer (eg. Stack authentication) some of these claims may not be present
agt := NewDefaultControlPlaneAgent(*claims)
for _, check := range ja.additionalChecks {
err := check(r, claims)
if err != nil {
return false, err
return agt, err
}
}

if !ja.checkScopes {
return true, nil
return agt, nil
}
valid, err := checkScopes(ja.service, r.Method, claims.Scopes)
if err != nil || !valid {
return agt, fmt.Errorf("scopes not valid: %w", err)
}
return agt, nil
}

func (ja *JWTAuth) AuthenticateOnControlPlane(r *http.Request) (ControlPlaneAgent, error) {
return ja.authenticate(r)
}

// Authenticate validates the JWT in the request and returns the user, if valid.
func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) {
_, err := ja.authenticate(r)
if err != nil {
return false, err
}
return checkScopes(ja.service, r.Method, claims.Scopes)
return true, nil
}

var (
Expand Down
87 changes: 73 additions & 14 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,19 @@ func setupTestKeySet(t *testing.T) (oidc.KeySet, *rsa.PrivateKey, string) {
return keySet, privateKey, issuer
}

func createAccessToken(t *testing.T, privateKey *rsa.PrivateKey, issuer string, scopes []string, subject string) string {
func createAccessToken(t *testing.T, privateKey *rsa.PrivateKey, issuer string, audience string, scopes []string, subject string) string {
now := stdtime.Now().UTC()
expirationTime := libtime.New(now.Add(1 * stdtime.Hour))

audiences := make([]string, 0, 1)
if audience != "" {
audiences = append(audiences, audience)
}

accessTokenClaims := oidc.NewAccessTokenClaims(
issuer,
subject,
[]string{"test-client"},
audiences,
expirationTime,
"test-jti",
"test-client",
Expand Down Expand Up @@ -82,17 +87,23 @@ func createAccessTokenWithOrgClaims(
t *testing.T,
privateKey *rsa.PrivateKey,
issuer string,
audience string,
scopes []string,
subject string,
organizationID string,
) string {
now := stdtime.Now().UTC()
expirationTime := libtime.New(now.Add(1 * stdtime.Hour))

audiences := make([]string, 0, 1)
if audience != "" {
audiences = append(audiences, audience)
}

accessTokenClaims := oidc.NewOrganizationAwareAccessTokenClaims(
issuer,
subject,
[]string{"test-client"},
audiences,
expirationTime,
"test-jti",
"test-client",
Expand Down Expand Up @@ -143,7 +154,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
auth := NewJWTAuth(keySet, issuer, "test-service", false, []AdditionalCheck{})

// Create access token
token := createAccessToken(t, privateKey, issuer, []string{}, "test-user")
token := createAccessToken(t, privateKey, issuer, "", []string{}, "test-user")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand Down Expand Up @@ -330,7 +341,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {

// Create access token with read scope
token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user")
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:read"}, "test-user")

req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
Expand Down Expand Up @@ -364,7 +375,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create access token with write scope
token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user")
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:write"}, "test-user")

req := httptest.NewRequest("POST", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
Expand Down Expand Up @@ -398,7 +409,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create access token with only read scope (not enough for POST)
token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user")
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:read"}, "test-user")

req := httptest.NewRequest("POST", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
Expand Down Expand Up @@ -433,7 +444,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create access token with write scope
token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user")
token := createAccessToken(t, privateKey, issuer, "", []string{"test-service:write"}, "test-user")

req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
Expand Down Expand Up @@ -468,7 +479,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create access token
token := createAccessToken(t, privateKey, unexpectedIssuer, []string{}, "test-user")
token := createAccessToken(t, privateKey, unexpectedIssuer, "", []string{}, "test-user")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand Down Expand Up @@ -508,7 +519,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
auth := NewJWTAuth(keySet, issuer, "test-service", false, autoFailingAdditionalChecks)

// Create access token
token := createAccessToken(t, privateKey, issuer, []string{}, "test-user")
token := createAccessToken(t, privateKey, issuer, "", []string{}, "test-user")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand All @@ -535,7 +546,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)

// Create access token
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", expectedOrgID)
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", expectedOrgID)

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand All @@ -558,7 +569,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)

// Create access token
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "")
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand All @@ -582,7 +593,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)

// Create access token
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "someotherorgid")
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "someotherorgid")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand All @@ -607,7 +618,7 @@ func TestJWTAuth_Authenticate(t *testing.T) {
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)

// Create access token
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "")
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
Expand All @@ -619,4 +630,52 @@ func TestJWTAuth_Authenticate(t *testing.T) {
assert.ErrorIs(t, err, oidc.ErrOrgIDNotPresent)
assert.False(t, authenticated)
})

t.Run("CheckAudienceClaim audience mismatches", func(t *testing.T) {
t.Parallel()
keySet, privateKey, issuer := setupTestKeySet(t)
expectedAudience := "http://expected.mydomain.com"

additionalChecks := []AdditionalCheck{
CheckAudienceClaim(expectedAudience),
}
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)

// Create access token
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, "", []string{}, "test-user", "")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
req = req.WithContext(logging.TestingContext())

authenticated, err := auth.Authenticate(nil, req)
require.Error(t, err)
assert.ErrorIs(t, err, oidc.ErrAudience)
assert.False(t, authenticated)
})

t.Run("CheckAudienceClaim audience matches", func(t *testing.T) {
t.Parallel()
keySet, privateKey, issuer := setupTestKeySet(t)
expectedAudience := "http://expected.mydomain.com"

additionalChecks := []AdditionalCheck{
CheckAudienceClaim(expectedAudience),
}
auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks)

// Create access token
tokenAudience := expectedAudience
token := createAccessTokenWithOrgClaims(t, privateKey, issuer, tokenAudience, []string{}, "test-user", "")

// Create request with valid token
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Authorization", "Bearer "+token)
req = req.WithContext(logging.TestingContext())

authenticated, err := auth.Authenticate(nil, req)
require.NoError(t, err)
assert.True(t, authenticated)
})
}
9 changes: 9 additions & 0 deletions auth/authenticator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package auth

import "net/http"

//go:generate mockgen -source authenticator.go -destination authenticator_generated.go -package auth . Authenticator
type Authenticator interface {
Authenticate(w http.ResponseWriter, r *http.Request) (bool, error)
AuthenticateOnControlPlane(r *http.Request) (ControlPlaneAgent, error)
}
19 changes: 17 additions & 2 deletions auth/authenticator_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading