diff --git a/auth/additional_checks.go b/auth/additional_checks.go new file mode 100644 index 00000000..e41ea40f --- /dev/null +++ b/auth/additional_checks.go @@ -0,0 +1,44 @@ +package auth + +import ( + "fmt" + "net/http" + + "github.com/formancehq/go-libs/v3/oidc" +) + +type AdditionalCheck func(*http.Request, *oidc.AccessTokenClaims) error + +// OrganizationIDProvider should give the authorizer the ability +// to know what orgID (if any) is associated with the resource the requester is attempting to access +// if no orgID is required, a blank string can be returned +type OrganizationIDProvider func(*http.Request) (orgID string, err error) + +func CheckOrganizationIDClaim(fn OrganizationIDProvider) AdditionalCheck { + return func(r *http.Request, rawClaims *oidc.AccessTokenClaims) error { + if rawClaims == nil { + return fmt.Errorf("claims cannot be nil") + } + claims := &oidc.OrganizationAwareAccessTokenClaims{AccessTokenClaims: *rawClaims} + + expectedOrgID, err := fn(r) + if err != nil { + return err + } + + // if the endpoint doesn't require a particular orgID we consider it valid + if expectedOrgID == "" { + return nil + } + + orgID := claims.GetOrganizationID() + if orgID == "" { + return oidc.ErrOrgIDNotPresent + } + + if expectedOrgID != "" && orgID != expectedOrgID { + return oidc.ErrOrgIDInvalid + } + return nil + } +} diff --git a/auth/auth.go b/auth/auth.go index 7e406301..70ba55c0 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -2,19 +2,18 @@ package auth import ( "errors" - "fmt" "net/http" "strings" - "github.com/formancehq/go-libs/v3/collectionutils" "github.com/formancehq/go-libs/v3/oidc" ) type JWTAuth struct { - issuer string - checkScopes bool - service string - keySet oidc.KeySet + issuer string + checkScopes bool + service string + keySet oidc.KeySet + additionalChecks []AdditionalCheck } func NewJWTAuth( @@ -22,41 +21,35 @@ func NewJWTAuth( issuer string, service string, checkScopes bool, + additionalChecks []AdditionalCheck, ) *JWTAuth { return &JWTAuth{ - issuer: issuer, - checkScopes: checkScopes, - service: service, - keySet: keySet, + issuer: issuer, + checkScopes: checkScopes, + service: service, + keySet: keySet, + additionalChecks: additionalChecks, } } // Authenticate validates the JWT in the request and returns the user, if valid. func (ja *JWTAuth) Authenticate(_ http.ResponseWriter, r *http.Request) (bool, error) { - claims, err := ClaimsFromRequest(r, ja.issuer, ja.keySet) if err != nil { return false, err } - if ja.checkScopes { - scope := claims.Scopes - - allowed := true //nolint:ineffassign - switch r.Method { - case http.MethodOptions, http.MethodGet, http.MethodHead, http.MethodTrace: - allowed = collectionutils.Contains(scope, ja.service+":read") || - collectionutils.Contains(scope, ja.service+":write") - default: - allowed = collectionutils.Contains(scope, ja.service+":write") - } - - if !allowed { - return false, fmt.Errorf("missing access, found scopes: '%s' need %s:read|write", strings.Join(scope, ", "), ja.service) + for _, check := range ja.additionalChecks { + err := check(r, claims) + if err != nil { + return false, err } } - return true, nil + if !ja.checkScopes { + return true, nil + } + return checkScopes(ja.service, r.Method, claims.Scopes) } var ( @@ -65,32 +58,43 @@ var ( ) func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySet) (*oidc.AccessTokenClaims, error) { + claims := &oidc.AccessTokenClaims{} + if err := claimsFromRequest(r, claims, keySet); err != nil { + return claims, err + } + + if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil { + return claims, err + } + + if err := oidc.CheckExpiration(claims, 0); err != nil { + return claims, err + } + + return claims, nil +} +func claimsFromRequest[CLAIMS any](r *http.Request, claims CLAIMS, keySet oidc.KeySet) error { authHeader := r.Header.Get("authorization") if authHeader == "" { - return nil, ErrNoAuthorizationHeader + return ErrNoAuthorizationHeader } if !strings.HasPrefix(authHeader, "bearer") && !strings.HasPrefix(authHeader, "Bearer") { - return nil, ErrMalformedHeader + return ErrMalformedHeader } token := authHeader[6:] token = strings.TrimSpace(token) - claims := &oidc.AccessTokenClaims{} decrypted, err := oidc.DecryptToken(token) if err != nil { - return nil, err + return err } payload, err := oidc.ParseToken(decrypted, &claims) if err != nil { - return nil, err - } - - if err := oidc.CheckIssuer(claims, expectedIssuer); err != nil { - return claims, err + return err } if _, err = oidc.CheckSignature( @@ -100,12 +104,8 @@ func ClaimsFromRequest(r *http.Request, expectedIssuer string, keySet oidc.KeySe []string{}, // Default to RS256 keySet, ); err != nil { - return claims, err + return err } - if err = oidc.CheckExpiration(claims, 0); err != nil { - return claims, err - } - - return claims, nil + return nil } diff --git a/auth/auth_test.go b/auth/auth_test.go index 6fa2e58d..9fb8eaf9 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -4,11 +4,14 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "errors" + "net/http" "net/http/httptest" "testing" stdtime "time" "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/formancehq/go-libs/v3/logging" @@ -75,14 +78,69 @@ func createAccessToken(t *testing.T, privateKey *rsa.PrivateKey, issuer string, return token } +func createAccessTokenWithOrgClaims( + t *testing.T, + privateKey *rsa.PrivateKey, + issuer string, + scopes []string, + subject string, + organizationID string, +) string { + now := stdtime.Now().UTC() + expirationTime := libtime.New(now.Add(1 * stdtime.Hour)) + + accessTokenClaims := oidc.NewOrganizationAwareAccessTokenClaims( + issuer, + subject, + []string{"test-client"}, + expirationTime, + "test-jti", + "test-client", + ) + + // Set scopes + accessTokenClaims.Scopes = scopes + + privateClaims := map[string]interface{}{} + if organizationID != "" { + privateClaims[oidc.ClaimOrganizationID] = organizationID + } + accessTokenClaims.Claims = privateClaims + + // Create JWT using go-jose + signer, err := jose.NewSigner( + jose.SigningKey{ + Algorithm: jose.RS256, + Key: privateKey, + }, + (&jose.SignerOptions{}).WithHeader("kid", "test-key-id"), + ) + require.NoError(t, err) + + claimsJSON, err := accessTokenClaims.MarshalJSON() + require.NoError(t, err) + + signed, err := signer.Sign(claimsJSON) + require.NoError(t, err) + + token, err := signed.CompactSerialize() + require.NoError(t, err) + + return token +} + func TestJWTAuth_Authenticate(t *testing.T) { t.Parallel() + autoPassingAdditionalChecks := []AdditionalCheck{ + func(*http.Request, *oidc.AccessTokenClaims) error { return nil }, + } + t.Run("success with valid token", func(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) - auth := NewJWTAuth(keySet, issuer, "test-service", false) + auth := NewJWTAuth(keySet, issuer, "test-service", false, []AdditionalCheck{}) // Create access token token := createAccessToken(t, privateKey, issuer, []string{}, "test-user") @@ -100,86 +158,359 @@ func TestJWTAuth_Authenticate(t *testing.T) { t.Run("failure without authorization header", func(t *testing.T) { t.Parallel() keySet, _, issuer := setupTestKeySet(t) + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", false, []AdditionalCheck{}), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks), + }, + } - auth := NewJWTAuth(keySet, issuer, "test-service", false) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "/test", nil) - req = req.WithContext(logging.TestingContext()) + req := httptest.NewRequest("GET", "/test", nil) + req = req.WithContext(logging.TestingContext()) - authenticated, err := auth.Authenticate(nil, req) - require.Error(t, err) - require.False(t, authenticated) - require.Contains(t, err.Error(), "no authorization header") + authenticated, err := tt.auth.Authenticate(nil, req) + require.Error(t, err) + assert.Contains(t, err.Error(), "no authorization header") + assert.False(t, authenticated) + }) + } }) t.Run("failure with malformed authorization header", func(t *testing.T) { t.Parallel() keySet, _, issuer := setupTestKeySet(t) + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", false, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Invalid token") + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.Error(t, err) + assert.False(t, authenticated) + assert.Contains(t, err.Error(), "malformed authorization header") + }) + } + }) - auth := NewJWTAuth(keySet, issuer, "test-service", false) + t.Run("failure with invalid token", func(t *testing.T) { + t.Parallel() + keySet, _, issuer := setupTestKeySet(t) + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", false, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer invalid-token") + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.Error(t, err) + require.False(t, authenticated) + }) + } + }) - req := httptest.NewRequest("GET", "/test", nil) - req.Header.Set("Authorization", "Invalid token") - req = req.WithContext(logging.TestingContext()) + t.Run("failure with expired token", func(t *testing.T) { + t.Parallel() + keySet, privateKey, issuer := setupTestKeySet(t) + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", false, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create an expired token + now := stdtime.Now().UTC() + expirationTime := libtime.New(now.Add(-1 * stdtime.Hour)) // Expired 1 hour ago + + accessTokenClaims := oidc.NewAccessTokenClaims( + issuer, + "test-user", + []string{"test-client"}, + expirationTime, + "test-jti", + "test-client", + ) + + signer, err := jose.NewSigner( + jose.SigningKey{ + Algorithm: jose.RS256, + Key: privateKey, + }, + (&jose.SignerOptions{}).WithHeader("kid", "test-key-id"), + ) + require.NoError(t, err) + + claimsJSON, err := json.Marshal(accessTokenClaims) + require.NoError(t, err) + + signed, err := signer.Sign(claimsJSON) + require.NoError(t, err) + + token, err := signed.CompactSerialize() + require.NoError(t, err) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.Error(t, err) + assert.False(t, authenticated) + }) + } + }) - authenticated, err := auth.Authenticate(nil, req) - require.Error(t, err) - require.False(t, authenticated) - require.Contains(t, err.Error(), "malformed authorization header") + t.Run("success with valid scopes for GET request", func(t *testing.T) { + t.Parallel() + keySet, privateKey, issuer := setupTestKeySet(t) + + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", true, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Create access token with read scope + token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user") + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.NoError(t, err) + assert.True(t, authenticated) + }) + } }) - t.Run("failure with invalid token", func(t *testing.T) { + t.Run("success with write scope for POST request", func(t *testing.T) { t.Parallel() - keySet, _, issuer := setupTestKeySet(t) + keySet, privateKey, issuer := setupTestKeySet(t) - auth := NewJWTAuth(keySet, issuer, "test-service", false) + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", true, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create access token with write scope + token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user") + + req := httptest.NewRequest("POST", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.NoError(t, err) + assert.True(t, authenticated) + }) + } + }) - req := httptest.NewRequest("GET", "/test", nil) - req.Header.Set("Authorization", "Bearer invalid-token") - req = req.WithContext(logging.TestingContext()) + t.Run("failure with insufficient scopes for POST request", func(t *testing.T) { + t.Parallel() + keySet, privateKey, issuer := setupTestKeySet(t) - authenticated, err := auth.Authenticate(nil, req) - require.Error(t, err) - require.False(t, authenticated) + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", true, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create access token with only read scope (not enough for POST) + token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user") + + req := httptest.NewRequest("POST", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.Error(t, err) + assert.False(t, authenticated) + assert.Contains(t, err.Error(), "missing access") + }) + } }) - t.Run("failure with expired token", func(t *testing.T) { + t.Run("success with write scope for GET request", func(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) - auth := NewJWTAuth(keySet, issuer, "test-service", false) - - // Create an expired token - now := stdtime.Now().UTC() - expirationTime := libtime.New(now.Add(-1 * stdtime.Hour)) // Expired 1 hour ago - - accessTokenClaims := oidc.NewAccessTokenClaims( - issuer, - "test-user", - []string{"test-client"}, - expirationTime, - "test-jti", - "test-client", - ) - - signer, err := jose.NewSigner( - jose.SigningKey{ - Algorithm: jose.RS256, - Key: privateKey, + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", true, nil), }, - (&jose.SignerOptions{}).WithHeader("kid", "test-key-id"), - ) - require.NoError(t, err) + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", true, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create access token with write scope + token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user") + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.NoError(t, err) + assert.True(t, authenticated) + }) + } + }) - claimsJSON, err := json.Marshal(accessTokenClaims) - require.NoError(t, err) + t.Run("failure with different issuer", func(t *testing.T) { + t.Parallel() + keySet, privateKey, issuer := setupTestKeySet(t) + unexpectedIssuer := "https://test-issuer.differentdomain.com" + + tests := []struct { + name string + auth Authenticator + }{ + { + name: "JWTAuth", + auth: NewJWTAuth(keySet, issuer, "test-service", false, nil), + }, + { + name: "JWTAuth with additional checks", + auth: NewJWTAuth(keySet, issuer, "test-service", false, autoPassingAdditionalChecks), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create access token + token := createAccessToken(t, privateKey, unexpectedIssuer, []string{}, "test-user") + + // Create request with valid token + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + req = req.WithContext(logging.TestingContext()) + + authenticated, err := tt.auth.Authenticate(nil, req) + require.Error(t, err) + assert.False(t, authenticated) + assert.ErrorIs(t, err, oidc.ErrIssuerInvalid) + }) + } + }) - signed, err := signer.Sign(claimsJSON) - require.NoError(t, err) + t.Run("failure due to additional check", func(t *testing.T) { + t.Parallel() + keySet, privateKey, issuer := setupTestKeySet(t) - token, err := signed.CompactSerialize() - require.NoError(t, err) + var additionalChecksPerformed = 0 + + expectedErr := errors.New("expected") + autoFailingAdditionalChecks := []AdditionalCheck{ + func(*http.Request, *oidc.AccessTokenClaims) error { + additionalChecksPerformed++ + return nil + }, + func(*http.Request, *oidc.AccessTokenClaims) error { + additionalChecksPerformed++ + return expectedErr + }, + func(*http.Request, *oidc.AccessTokenClaims) error { + additionalChecksPerformed++ + return nil + }, + } + + auth := NewJWTAuth(keySet, issuer, "test-service", false, autoFailingAdditionalChecks) + // Create access token + token := createAccessToken(t, privateKey, issuer, []string{}, "test-user") + + // Create request with valid token req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("Authorization", "Bearer "+token) req = req.WithContext(logging.TestingContext()) @@ -187,78 +518,105 @@ func TestJWTAuth_Authenticate(t *testing.T) { authenticated, err := auth.Authenticate(nil, req) require.Error(t, err) require.False(t, authenticated) + assert.ErrorIs(t, err, expectedErr) + assert.Equal(t, 2, additionalChecksPerformed) }) - t.Run("success with valid scopes for GET request", func(t *testing.T) { + t.Run("CheckOrganizationIDClaim success with valid token and correct orgID", func(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) + expectedOrgID := "abcdefghijkl" + + provider := func(*http.Request) (string, error) { return expectedOrgID, nil } + additionalChecks := []AdditionalCheck{ + CheckOrganizationIDClaim(provider), + } - auth := NewJWTAuth(keySet, issuer, "test-service", true) + auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks) - // Create access token with read scope - token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user") + // Create access token + token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", expectedOrgID) + // Create request with valid token req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("Authorization", "Bearer "+token) req = req.WithContext(logging.TestingContext()) authenticated, err := auth.Authenticate(nil, req) require.NoError(t, err) - require.True(t, authenticated) + assert.True(t, authenticated) }) - t.Run("success with write scope for POST request", func(t *testing.T) { + t.Run("CheckOrganizationIDClaim success with valid token and no expected orgID", func(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) - auth := NewJWTAuth(keySet, issuer, "test-service", true) + provider := func(*http.Request) (string, error) { return "", nil } + additionalChecks := []AdditionalCheck{ + CheckOrganizationIDClaim(provider), + } + auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks) - // Create access token with write scope - token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user") + // Create access token + token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "") - req := httptest.NewRequest("POST", "/test", nil) + // Create request with valid token + req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("Authorization", "Bearer "+token) req = req.WithContext(logging.TestingContext()) authenticated, err := auth.Authenticate(nil, req) require.NoError(t, err) - require.True(t, authenticated) + assert.True(t, authenticated) }) - t.Run("failure with insufficient scopes for POST request", func(t *testing.T) { + t.Run("CheckOrganizationIDClaim failure with valid token and mismatched orgID", func(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) + expectedOrgID := "abcdefghijkl" - auth := NewJWTAuth(keySet, issuer, "test-service", true) + provider := func(*http.Request) (string, error) { return expectedOrgID, nil } + additionalChecks := []AdditionalCheck{ + CheckOrganizationIDClaim(provider), + } + auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks) - // Create access token with only read scope (not enough for POST) - token := createAccessToken(t, privateKey, issuer, []string{"test-service:read"}, "test-user") + // Create access token + token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "someotherorgid") - req := httptest.NewRequest("POST", "/test", nil) + // Create request with valid token + req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("Authorization", "Bearer "+token) req = req.WithContext(logging.TestingContext()) authenticated, err := auth.Authenticate(nil, req) require.Error(t, err) - require.False(t, authenticated) - require.Contains(t, err.Error(), "missing access") + assert.ErrorIs(t, err, oidc.ErrOrgIDInvalid) + assert.False(t, authenticated) }) - t.Run("success with write scope for GET request", func(t *testing.T) { + t.Run("CheckOrganizationIDClaim failure with token that doesn't contain orgID", func(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) + expectedOrgID := "abcdefghijkl" - auth := NewJWTAuth(keySet, issuer, "test-service", true) + provider := func(*http.Request) (string, error) { return expectedOrgID, nil } + additionalChecks := []AdditionalCheck{ + CheckOrganizationIDClaim(provider), + } + auth := NewJWTAuth(keySet, issuer, "test-service", false, additionalChecks) - // Create access token with write scope - token := createAccessToken(t, privateKey, issuer, []string{"test-service:write"}, "test-user") + // Create access token + token := createAccessTokenWithOrgClaims(t, privateKey, issuer, []string{}, "test-user", "") + // Create request with valid token req := httptest.NewRequest("GET", "/test", nil) req.Header.Set("Authorization", "Bearer "+token) req = req.WithContext(logging.TestingContext()) authenticated, err := auth.Authenticate(nil, req) - require.NoError(t, err) - require.True(t, authenticated) + require.Error(t, err) + assert.ErrorIs(t, err, oidc.ErrOrgIDNotPresent) + assert.False(t, authenticated) }) } diff --git a/auth/authenticator_generated.go b/auth/authenticator_generated.go new file mode 100644 index 00000000..ed97551c --- /dev/null +++ b/auth/authenticator_generated.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: middleware.go +// +// Generated by this command: +// +// mockgen -source middleware.go -destination authenticator_generated.go -package auth . Authenticator +// + +// Package auth is a generated GoMock package. +package auth + +import ( + http "net/http" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockAuthenticator is a mock of Authenticator interface. +type MockAuthenticator struct { + ctrl *gomock.Controller + recorder *MockAuthenticatorMockRecorder + isgomock struct{} +} + +// MockAuthenticatorMockRecorder is the mock recorder for MockAuthenticator. +type MockAuthenticatorMockRecorder struct { + mock *MockAuthenticator +} + +// NewMockAuthenticator creates a new mock instance. +func NewMockAuthenticator(ctrl *gomock.Controller) *MockAuthenticator { + mock := &MockAuthenticator{ctrl: ctrl} + mock.recorder = &MockAuthenticatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAuthenticator) EXPECT() *MockAuthenticatorMockRecorder { + return m.recorder +} + +// Authenticate mocks base method. +func (m *MockAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Authenticate", w, r) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Authenticate indicates an expected call of Authenticate. +func (mr *MockAuthenticatorMockRecorder) Authenticate(w, r any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Authenticate", reflect.TypeOf((*MockAuthenticator)(nil).Authenticate), w, r) +} diff --git a/auth/cli.go b/auth/cli.go index de4e2240..1896b272 100644 --- a/auth/cli.go +++ b/auth/cli.go @@ -22,18 +22,35 @@ func AddFlags(flags *flag.FlagSet) { flags.String(AuthServiceFlag, "", "Service") } -func FXModuleFromFlags(cmd *cobra.Command) fx.Option { +func defaultModuleConfig(cmd *cobra.Command) ModuleConfig { authEnabled, _ := cmd.Flags().GetBool(AuthEnabledFlag) authIssuer, _ := cmd.Flags().GetString(AuthIssuerFlag) authReadKeySetMaxRetries, _ := cmd.Flags().GetInt(AuthReadKeySetMaxRetriesFlag) authCheckScopes, _ := cmd.Flags().GetBool(AuthCheckScopesFlag) authService, _ := cmd.Flags().GetString(AuthServiceFlag) - return Module(ModuleConfig{ + return ModuleConfig{ Enabled: authEnabled, Issuer: authIssuer, ReadKeySetMaxRetries: authReadKeySetMaxRetries, CheckScopes: authCheckScopes, Service: authService, - }) + AdditionalChecks: make([]AdditionalCheck, 0), + } +} + +func FXModuleFromFlags(cmd *cobra.Command) fx.Option { + return Module(defaultModuleConfig(cmd)) +} + +func OrganizationAwareFXModuleFromFlags(cmd *cobra.Command, fn OrganizationIDProvider) fx.Option { + cfg := defaultModuleConfig(cmd) + cfg.AdditionalChecks = append(cfg.AdditionalChecks, CheckOrganizationIDClaim(fn)) + return Module(cfg) +} + +func AdditionalChecksFXModuleFromFlags(cmd *cobra.Command, checks ...AdditionalCheck) fx.Option { + cfg := defaultModuleConfig(cmd) + cfg.AdditionalChecks = append(cfg.AdditionalChecks, checks...) + return Module(cfg) } diff --git a/auth/middleware.go b/auth/middleware.go index 304df881..45f61199 100644 --- a/auth/middleware.go +++ b/auth/middleware.go @@ -1,9 +1,13 @@ package auth import ( + "errors" "net/http" + + "github.com/formancehq/go-libs/v3/oidc" ) +//go:generate mockgen -source middleware.go -destination authenticator_generated.go -package auth . Authenticator type Authenticator interface { Authenticate(w http.ResponseWriter, r *http.Request) (bool, error) } @@ -13,6 +17,12 @@ func Middleware(ja Authenticator) func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authenticated, err := ja.Authenticate(w, r) if err != nil { + // client is authenticated but doesn't have permission to access this resource + if errors.Is(err, oidc.ErrOrgIDNotPresent) || errors.Is(err, oidc.ErrOrgIDInvalid) { + w.WriteHeader(http.StatusForbidden) + return + } + w.WriteHeader(http.StatusUnauthorized) return } diff --git a/auth/middleware_test.go b/auth/middleware_test.go index f0b8385a..37f3e516 100644 --- a/auth/middleware_test.go +++ b/auth/middleware_test.go @@ -1,13 +1,16 @@ package auth import ( + "fmt" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" "github.com/formancehq/go-libs/v3/logging" + "github.com/formancehq/go-libs/v3/oidc" ) func TestMiddleware(t *testing.T) { @@ -17,7 +20,7 @@ func TestMiddleware(t *testing.T) { t.Parallel() keySet, privateKey, issuer := setupTestKeySet(t) - authenticator := NewJWTAuth(keySet, issuer, "test-service", false) + authenticator := NewJWTAuth(keySet, issuer, "test-service", false, nil) // Create access token token := createAccessToken(t, privateKey, issuer, []string{}, "test-user") @@ -42,7 +45,7 @@ func TestMiddleware(t *testing.T) { t.Parallel() keySet, _, issuer := setupTestKeySet(t) - authenticator := NewJWTAuth(keySet, issuer, "test-service", false) + authenticator := NewJWTAuth(keySet, issuer, "test-service", false, nil) handler := Middleware(authenticator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -57,4 +60,43 @@ func TestMiddleware(t *testing.T) { require.Equal(t, http.StatusUnauthorized, rr.Code) }) + + t.Run("forbidden", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + authError error + }{ + { + name: "Invalid OrgID", + authError: fmt.Errorf("err: %w", oidc.ErrOrgIDInvalid), + }, + { + name: "OrgID missing from token", + authError: fmt.Errorf("err: %w", oidc.ErrOrgIDNotPresent), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + authenticator := NewMockAuthenticator(ctrl) + + handler := Middleware(authenticator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req.Header.Set("Authorization", "Bearer mock-token") + req = req.WithContext(logging.TestingContext()) + + authenticator.EXPECT().Authenticate(gomock.Any(), gomock.Any()).Return(true, tt.authError) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + require.Equal(t, http.StatusForbidden, rr.Code) + }) + } + }) } diff --git a/auth/module.go b/auth/module.go index 316f4398..4d2dc068 100644 --- a/auth/module.go +++ b/auth/module.go @@ -17,46 +17,52 @@ type ModuleConfig struct { ReadKeySetMaxRetries int CheckScopes bool Service string + + AdditionalChecks []AdditionalCheck } func Module(cfg ModuleConfig) fx.Option { options := make([]fx.Option, 0) - if cfg.Enabled { - options = append(options, - fx.Supply(http.DefaultClient), - fx.Provide(func(httpClient *http.Client) (oidc.KeySet, error) { - retryableHttpClient := retryablehttp.NewClient() - retryableHttpClient.RetryMax = cfg.ReadKeySetMaxRetries - retryableHttpClient.HTTPClient = httpClient - - discovery, err := client.Discover[oidc.DiscoveryConfiguration]( - context.Background(), - cfg.Issuer, - retryableHttpClient.StandardClient(), - ) - if err != nil { - return nil, err - } - - return client.NewRemoteKeySet(httpClient, discovery.JwksURI), nil - }), - fx.Provide(func(keySet oidc.KeySet) Authenticator { - return NewJWTAuth( - keySet, - cfg.Issuer, - cfg.Service, - cfg.CheckScopes, - ) - }), - ) - } else { + if !cfg.Enabled { options = append(options, fx.Provide(func() Authenticator { return NewNoAuth() }), ) + return fx.Module("auth", options...) } + options = append(options, + fx.Supply(http.DefaultClient), + fx.Provide(func(httpClient *http.Client) (oidc.KeySet, error) { + retryableHttpClient := retryablehttp.NewClient() + retryableHttpClient.RetryMax = cfg.ReadKeySetMaxRetries + retryableHttpClient.HTTPClient = httpClient + + discovery, err := client.Discover[oidc.DiscoveryConfiguration]( + context.Background(), + cfg.Issuer, + retryableHttpClient.StandardClient(), + ) + if err != nil { + return nil, err + } + + return client.NewRemoteKeySet(httpClient, discovery.JwksURI), nil + }), + ) + + options = append(options, + fx.Provide(func(keySet oidc.KeySet) Authenticator { + return NewJWTAuth( + keySet, + cfg.Issuer, + cfg.Service, + cfg.CheckScopes, + cfg.AdditionalChecks, + ) + }), + ) return fx.Module("auth", options...) } diff --git a/auth/module_test.go b/auth/module_test.go index c600b493..79b2eff0 100644 --- a/auth/module_test.go +++ b/auth/module_test.go @@ -99,6 +99,53 @@ func TestModule(t *testing.T) { } }) + t.Run("module with additional checks calls discovery endpoint when enabled", func(t *testing.T) { + t.Parallel() + _, issuer, discoveryCalled := setupTestOIDCServer(t) + + var authenticator auth.Authenticator + + provider := func(*http.Request) (string, error) { return "dummy", nil } + additionalChecks := []auth.AdditionalCheck{ + auth.CheckOrganizationIDClaim(provider), + } + + options := []fx.Option{ + auth.Module(auth.ModuleConfig{ + Enabled: true, + Issuer: issuer, + Service: "test-service-with-orgId-aware-auth", + CheckScopes: false, + AdditionalChecks: additionalChecks, + }), + fx.Provide(func() context.Context { + return context.Background() + }), + fx.Provide(func() logging.Logger { + return logging.Testing() + }), + fx.Populate(&authenticator), + } + + if !testing.Verbose() { + options = append(options, fx.NopLogger) + } + + app := fxtest.New(t, options...) + app.RequireStart() + defer app.RequireStop() + + require.NotNil(t, authenticator) + + // Verify that the discovery endpoint was called + select { + case called := <-discoveryCalled: + require.True(t, called, "Discovery endpoint should have been called") + default: + t.Fatal("Discovery endpoint was not called") + } + }) + t.Run("module can be overridden with fx.Decorate", func(t *testing.T) { t.Parallel() diff --git a/auth/scopes.go b/auth/scopes.go new file mode 100644 index 00000000..0b2b5930 --- /dev/null +++ b/auth/scopes.go @@ -0,0 +1,26 @@ +package auth + +import ( + "fmt" + "net/http" + "strings" + + "github.com/formancehq/go-libs/v3/collectionutils" + "github.com/formancehq/go-libs/v3/oidc" +) + +func checkScopes(service string, method string, scopes oidc.SpaceDelimitedArray) (bool, error) { + allowed := true //nolint:ineffassign + switch method { + case http.MethodOptions, http.MethodGet, http.MethodHead, http.MethodTrace: + allowed = collectionutils.Contains(scopes, service+":read") || + collectionutils.Contains(scopes, service+":write") + default: + allowed = collectionutils.Contains(scopes, service+":write") + } + + if !allowed { + return false, fmt.Errorf("missing access, found scopes: '%s' need %s:read|write", strings.Join(scopes, ", "), service) + } + return true, nil +} diff --git a/oidc/organization_aware_access_token_claims.go b/oidc/organization_aware_access_token_claims.go new file mode 100644 index 00000000..cf2581a0 --- /dev/null +++ b/oidc/organization_aware_access_token_claims.go @@ -0,0 +1,27 @@ +package oidc + +import "github.com/formancehq/go-libs/v3/time" + +const ClaimOrganizationID = "organization_id" + +// Convenience wrapper for fetching orgID from custom claims +type OrganizationAwareAccessTokenClaims struct { + AccessTokenClaims +} + +func NewOrganizationAwareAccessTokenClaims(issuer, subject string, audience []string, expiration time.Time, jwtid, clientID string) *OrganizationAwareAccessTokenClaims { + atc := NewAccessTokenClaims(issuer, subject, audience, expiration, jwtid, clientID) + return &OrganizationAwareAccessTokenClaims{*atc} +} + +func (o *OrganizationAwareAccessTokenClaims) GetOrganizationID() string { + val, ok := o.Claims[ClaimOrganizationID] + if !ok { + return "" + } + + if orgID, ok := val.(string); ok { + return orgID + } + return "" +} diff --git a/oidc/verifier.go b/oidc/verifier.go index f8a71f4e..27cbcca4 100644 --- a/oidc/verifier.go +++ b/oidc/verifier.go @@ -54,6 +54,9 @@ var ( ErrAuthTimeNotPresent = errors.New("claim `auth_time` of token is missing") ErrAuthTimeToOld = errors.New("auth time of token is too old") ErrAtHash = errors.New("at_hash does not correspond to access token") + + ErrOrgIDNotPresent = errors.New("claim `organization_id` of token is missing") + ErrOrgIDInvalid = errors.New("organization does not match") ) // Verifier caries configuration for the various token verification